aws_config/imds/
client.rs

1/*
2 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3 * SPDX-License-Identifier: Apache-2.0
4 */
5
6//! Raw IMDSv2 Client
7//!
8//! Client for direct access to IMDSv2.
9
10use crate::imds::client::error::{BuildError, ImdsError, InnerImdsError, InvalidEndpointMode};
11use crate::imds::client::token::TokenRuntimePlugin;
12use crate::provider_config::ProviderConfig;
13use crate::PKG_VERSION;
14use aws_runtime::user_agent::{ApiMetadata, AwsUserAgent, UserAgentInterceptor};
15use aws_smithy_runtime::client::orchestrator::operation::Operation;
16use aws_smithy_runtime::client::retries::strategy::StandardRetryStrategy;
17use aws_smithy_runtime_api::box_error::BoxError;
18use aws_smithy_runtime_api::client::auth::AuthSchemeOptionResolverParams;
19use aws_smithy_runtime_api::client::endpoint::{
20    EndpointFuture, EndpointResolverParams, ResolveEndpoint,
21};
22use aws_smithy_runtime_api::client::interceptors::context::InterceptorContext;
23use aws_smithy_runtime_api::client::orchestrator::{
24    HttpRequest, OrchestratorError, SensitiveOutput,
25};
26use aws_smithy_runtime_api::client::result::ConnectorError;
27use aws_smithy_runtime_api::client::result::SdkError;
28use aws_smithy_runtime_api::client::retries::classifiers::{
29    ClassifyRetry, RetryAction, SharedRetryClassifier,
30};
31use aws_smithy_runtime_api::client::runtime_components::RuntimeComponentsBuilder;
32use aws_smithy_runtime_api::client::runtime_plugin::{RuntimePlugin, SharedRuntimePlugin};
33use aws_smithy_types::body::SdkBody;
34use aws_smithy_types::config_bag::{FrozenLayer, Layer};
35use aws_smithy_types::endpoint::Endpoint;
36use aws_smithy_types::retry::RetryConfig;
37use aws_smithy_types::timeout::TimeoutConfig;
38use aws_types::os_shim_internal::Env;
39use http::Uri;
40use std::borrow::Cow;
41use std::error::Error as _;
42use std::fmt;
43use std::str::FromStr;
44use std::sync::Arc;
45use std::time::Duration;
46
47pub mod error;
48mod token;
49
/// Default lifetime requested for IMDS session tokens: 6 hours.
const DEFAULT_TOKEN_TTL: Duration = Duration::from_secs(6 * 60 * 60);
/// Default maximum number of attempts made for each token or metadata request.
const DEFAULT_ATTEMPTS: u32 = 4;
/// Default limit on establishing a connection to IMDS.
const DEFAULT_CONNECT_TIMEOUT: Duration = Duration::from_secs(1);
/// Default limit on reading a response from IMDS.
const DEFAULT_READ_TIMEOUT: Duration = Duration::from_secs(1);
/// Default limit on a whole operation, across all of its attempts.
const DEFAULT_OPERATION_TIMEOUT: Duration = Duration::from_secs(30);
/// Default limit on a single attempt of an operation.
const DEFAULT_OPERATION_ATTEMPT_TIMEOUT: Duration = Duration::from_secs(10);
57
58fn user_agent() -> AwsUserAgent {
59    AwsUserAgent::new_from_environment(Env::real(), ApiMetadata::new("imds", PKG_VERSION))
60}
61
62/// IMDSv2 Client
63///
64/// Client for IMDSv2. This client handles fetching tokens, retrying on failure, and token
65/// caching according to the specified token TTL.
66///
67/// _Note: This client ONLY supports IMDSv2. It will not fallback to IMDSv1. See
68/// [transitioning to IMDSv2](https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/configuring-instance-metadata-service.html#instance-metadata-transition-to-version-2)
69/// for more information._
70///
71/// **Note**: When running in a Docker container, all network requests will incur an additional hop. When combined with the default IMDS hop limit of 1, this will cause requests to IMDS to timeout! To fix this issue, you'll need to set the following instance metadata settings :
72/// ```txt
73/// amazonec2-metadata-token=required
74/// amazonec2-metadata-token-response-hop-limit=2
75/// ```
76///
77/// On an instance that is already running, these can be set with [ModifyInstanceMetadataOptions](https://docs.aws.amazon.com/AWSEC2/latest/APIReference/API_ModifyInstanceMetadataOptions.html). On a new instance, these can be set with the `MetadataOptions` field on [RunInstances](https://docs.aws.amazon.com/AWSEC2/latest/APIReference/API_RunInstances.html).
78///
79/// For more information about IMDSv2 vs. IMDSv1 see [this guide](https://docs.aws.amazon.com/AWSEC2/latest/WindowsGuide/configuring-instance-metadata-service.html)
80///
81/// # Client Configuration
82/// The IMDS client can load configuration explicitly, via environment variables, or via
83/// `~/.aws/config`. It will first attempt to resolve an endpoint override. If no endpoint
84/// override exists, it will attempt to resolve an [`EndpointMode`]. If no
85/// [`EndpointMode`] override exists, it will fallback to [`IpV4`](EndpointMode::IpV4). An exhaustive
86/// list is below:
87///
88/// ## Endpoint configuration list
89/// 1. Explicit configuration of `Endpoint` via the [builder](Builder):
90/// ```no_run
91/// use aws_config::imds::client::Client;
92/// # async fn docs() {
93/// let client = Client::builder()
94///   .endpoint("http://customimds:456/").expect("valid URI")
95///   .build();
96/// # }
97/// ```
98///
99/// 2. The `AWS_EC2_METADATA_SERVICE_ENDPOINT` environment variable. Note: If this environment variable
100///    is set, it MUST contain a valid URI or client construction will fail.
101///
102/// 3. The `ec2_metadata_service_endpoint` field in `~/.aws/config`:
103/// ```ini
104/// [default]
105/// # ... other configuration
106/// ec2_metadata_service_endpoint = http://my-custom-endpoint:444
107/// ```
108///
109/// 4. An explicitly set endpoint mode:
110/// ```no_run
111/// use aws_config::imds::client::{Client, EndpointMode};
112/// # async fn docs() {
113/// let client = Client::builder().endpoint_mode(EndpointMode::IpV6).build();
114/// # }
115/// ```
116///
117/// 5. An [endpoint mode](EndpointMode) loaded from the `AWS_EC2_METADATA_SERVICE_ENDPOINT_MODE` environment
118///    variable. Valid values: `IPv4`, `IPv6`
119///
120/// 6. An [endpoint mode](EndpointMode) loaded from the `ec2_metadata_service_endpoint_mode` field in
121///    `~/.aws/config`:
122/// ```ini
123/// [default]
124/// # ... other configuration
125/// ec2_metadata_service_endpoint_mode = IPv4
126/// ```
127///
128/// 7. The default value of `http://169.254.169.254` will be used.
129///
130#[derive(Clone, Debug)]
131pub struct Client {
132    operation: Operation<String, SensitiveString, InnerImdsError>,
133}
134
impl Client {
    /// IMDS client builder
    ///
    /// Returns a [`Builder`] with every option unset; unset options fall back to the builder's
    /// defaults when [`Builder::build`] is called.
    pub fn builder() -> Builder {
        Builder::default()
    }

    /// Retrieve information from IMDS
    ///
    /// This method will handle loading and caching a session token, combining the `path` with the
    /// configured IMDS endpoint, and retrying potential errors.
    ///
    /// For more information about IMDSv2 methods and functionality, see
    /// [Instance metadata and user data](https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/ec2-instance-metadata.html)
    ///
    /// # Errors
    ///
    /// Every failure is reported as an [`ImdsError`]:
    /// - If building the request fails and the underlying cause is an `ImdsError`, that error is
    ///   returned unchanged; other construction failures become unexpected errors.
    /// - A response with a non-success status becomes an error response carrying the raw HTTP
    ///   response.
    /// - A successful response whose body is not valid UTF-8 becomes an unexpected error.
    /// - If a dispatch failure wraps a connector error whose cause is an `ImdsError`, that error
    ///   is returned unchanged; other dispatch failures become unexpected errors.
    /// - Timeouts and response errors become I/O errors.
    /// - Anything else becomes an unexpected error.
    ///
    /// # Examples
    ///
    /// ```no_run
    /// use aws_config::imds::client::Client;
    /// # async fn docs() {
    /// let client = Client::builder().build();
    /// let ami_id = client
    ///   .get("/latest/meta-data/ami-id")
    ///   .await
    ///   .expect("failure communicating with IMDS");
    /// # }
    /// ```
    pub async fn get(&self, path: impl Into<String>) -> Result<SensitiveString, ImdsError> {
        self.operation
            .invoke(path.into())
            .await
            .map_err(|err| match err {
                // Request construction failed with an underlying cause. When that cause is an
                // `ImdsError` (the binding name suggests this is how token-fetch failures arrive
                // here), hand it back unchanged instead of wrapping it a second time.
                SdkError::ConstructionFailure(_) if err.source().is_some() => {
                    match err.into_source().map(|e| e.downcast::<ImdsError>()) {
                        Ok(Ok(token_failure)) => *token_failure,
                        Ok(Err(err)) => ImdsError::unexpected(err),
                        Err(err) => ImdsError::unexpected(err),
                    }
                }
                // Construction failed with no cause to unwrap.
                SdkError::ConstructionFailure(_) => ImdsError::unexpected(err),
                // The response deserializer rejected the response; translate its error kind.
                SdkError::ServiceError(context) => match context.err() {
                    InnerImdsError::InvalidUtf8 => {
                        ImdsError::unexpected("IMDS returned invalid UTF-8")
                    }
                    // Keep the raw HTTP response so callers can inspect the status and body.
                    InnerImdsError::BadStatus => ImdsError::error_response(context.into_raw()),
                },
                // If the error source is an ImdsError, then we need to directly return that source.
                // That way, the IMDS token provider's errors can become the top-level ImdsError.
                // There is a unit test that checks the correct error is being extracted.
                err @ SdkError::DispatchFailure(_) => match err.into_source() {
                    Ok(source) => match source.downcast::<ConnectorError>() {
                        Ok(source) => match source.into_source().downcast::<ImdsError>() {
                            Ok(source) => *source,
                            Err(err) => ImdsError::unexpected(err),
                        },
                        Err(err) => ImdsError::unexpected(err),
                    },
                    Err(err) => ImdsError::unexpected(err),
                },
                SdkError::TimeoutError(_) | SdkError::ResponseError(_) => ImdsError::io_error(err),
                // `SdkError` may gain variants; treat anything unrecognized as unexpected.
                _ => ImdsError::unexpected(err),
            })
    }
}
198
/// A `String` wrapper whose `Debug` output hides the wrapped contents.
///
/// Formatting a value with `{:?}` prints `SensitiveString("** redacted **")` rather than the
/// text itself. Read the text through [`AsRef<str>`] or convert it back with `String::from`.
#[derive(Clone)]
pub struct SensitiveString(String);

impl fmt::Debug for SensitiveString {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Printed in place of the real contents.
        const PLACEHOLDER: &str = "** redacted **";
        let mut tuple = f.debug_tuple("SensitiveString");
        tuple.field(&PLACEHOLDER);
        tuple.finish()
    }
}

impl AsRef<str> for SensitiveString {
    fn as_ref(&self) -> &str {
        self.0.as_str()
    }
}

impl From<String> for SensitiveString {
    fn from(value: String) -> Self {
        SensitiveString(value)
    }
}

impl From<SensitiveString> for String {
    fn from(value: SensitiveString) -> Self {
        let SensitiveString(inner) = value;
        inner
    }
}
228
229/// Runtime plugin that is used by both the IMDS client and the inner client that resolves
230/// the IMDS token and attaches it to requests. This runtime plugin marks the responses as
231/// sensitive, configures user agent headers, and sets up retries and timeouts.
232#[derive(Debug)]
233struct ImdsCommonRuntimePlugin {
234    config: FrozenLayer,
235    components: RuntimeComponentsBuilder,
236}
237
238impl ImdsCommonRuntimePlugin {
239    fn new(
240        config: &ProviderConfig,
241        endpoint_resolver: ImdsEndpointResolver,
242        retry_config: RetryConfig,
243        retry_classifier: SharedRetryClassifier,
244        timeout_config: TimeoutConfig,
245    ) -> Self {
246        let mut layer = Layer::new("ImdsCommonRuntimePlugin");
247        layer.store_put(AuthSchemeOptionResolverParams::new(()));
248        layer.store_put(EndpointResolverParams::new(()));
249        layer.store_put(SensitiveOutput);
250        layer.store_put(retry_config);
251        layer.store_put(timeout_config);
252        layer.store_put(user_agent());
253
254        Self {
255            config: layer.freeze(),
256            components: RuntimeComponentsBuilder::new("ImdsCommonRuntimePlugin")
257                .with_http_client(config.http_client())
258                .with_endpoint_resolver(Some(endpoint_resolver))
259                .with_interceptor(UserAgentInterceptor::new())
260                .with_retry_classifier(retry_classifier)
261                .with_retry_strategy(Some(StandardRetryStrategy::new()))
262                .with_time_source(Some(config.time_source()))
263                .with_sleep_impl(config.sleep_impl()),
264        }
265    }
266}
267
268impl RuntimePlugin for ImdsCommonRuntimePlugin {
269    fn config(&self) -> Option<FrozenLayer> {
270        Some(self.config.clone())
271    }
272
273    fn runtime_components(
274        &self,
275        _current_components: &RuntimeComponentsBuilder,
276    ) -> Cow<'_, RuntimeComponentsBuilder> {
277        Cow::Borrowed(&self.components)
278    }
279}
280
281/// IMDSv2 Endpoint Mode
282///
283/// IMDS can be accessed in two ways:
284/// 1. Via the IpV4 endpoint: `http://169.254.169.254`
285/// 2. Via the Ipv6 endpoint: `http://[fd00:ec2::254]`
286#[derive(Debug, Clone)]
287#[non_exhaustive]
288pub enum EndpointMode {
289    /// IpV4 mode: `http://169.254.169.254`
290    ///
291    /// This mode is the default unless otherwise specified.
292    IpV4,
293    /// IpV6 mode: `http://[fd00:ec2::254]`
294    IpV6,
295}
296
297impl FromStr for EndpointMode {
298    type Err = InvalidEndpointMode;
299
300    fn from_str(value: &str) -> Result<Self, Self::Err> {
301        match value {
302            _ if value.eq_ignore_ascii_case("ipv4") => Ok(EndpointMode::IpV4),
303            _ if value.eq_ignore_ascii_case("ipv6") => Ok(EndpointMode::IpV6),
304            other => Err(InvalidEndpointMode::new(other.to_owned())),
305        }
306    }
307}
308
309impl EndpointMode {
310    /// IMDS URI for this endpoint mode
311    fn endpoint(&self) -> Uri {
312        match self {
313            EndpointMode::IpV4 => Uri::from_static("http://169.254.169.254"),
314            EndpointMode::IpV6 => Uri::from_static("http://[fd00:ec2::254]"),
315        }
316    }
317}
318
319/// IMDSv2 Client Builder
320#[derive(Default, Debug, Clone)]
321pub struct Builder {
322    max_attempts: Option<u32>,
323    endpoint: Option<EndpointSource>,
324    mode_override: Option<EndpointMode>,
325    token_ttl: Option<Duration>,
326    connect_timeout: Option<Duration>,
327    read_timeout: Option<Duration>,
328    operation_timeout: Option<Duration>,
329    operation_attempt_timeout: Option<Duration>,
330    config: Option<ProviderConfig>,
331    retry_classifier: Option<SharedRetryClassifier>,
332}
333
334impl Builder {
335    /// Override the number of retries for fetching tokens & metadata
336    ///
337    /// By default, 4 attempts will be made.
338    pub fn max_attempts(mut self, max_attempts: u32) -> Self {
339        self.max_attempts = Some(max_attempts);
340        self
341    }
342
343    /// Configure generic options of the [`Client`]
344    ///
345    /// # Examples
346    /// ```no_run
347    /// # async fn test() {
348    /// use aws_config::imds::Client;
349    /// use aws_config::provider_config::ProviderConfig;
350    ///
351    /// let provider = Client::builder()
352    ///     .configure(&ProviderConfig::with_default_region().await)
353    ///     .build();
354    /// # }
355    /// ```
356    pub fn configure(mut self, provider_config: &ProviderConfig) -> Self {
357        self.config = Some(provider_config.clone());
358        self
359    }
360
361    /// Override the endpoint for the [`Client`]
362    ///
363    /// By default, the client will resolve an endpoint from the environment, AWS config, and endpoint mode.
364    ///
365    /// See [`Client`] for more information.
366    pub fn endpoint(mut self, endpoint: impl AsRef<str>) -> Result<Self, BoxError> {
367        let uri: Uri = endpoint.as_ref().parse()?;
368        self.endpoint = Some(EndpointSource::Explicit(uri));
369        Ok(self)
370    }
371
372    /// Override the endpoint mode for [`Client`]
373    ///
374    /// * When set to [`IpV4`](EndpointMode::IpV4), the endpoint will be `http://169.254.169.254`.
375    /// * When set to [`IpV6`](EndpointMode::IpV6), the endpoint will be `http://[fd00:ec2::254]`.
376    pub fn endpoint_mode(mut self, mode: EndpointMode) -> Self {
377        self.mode_override = Some(mode);
378        self
379    }
380
381    /// Override the time-to-live for the session token
382    ///
383    /// Requests to IMDS utilize a session token for authentication. By default, session tokens last
384    /// for 6 hours. When the TTL for the token expires, a new token must be retrieved from the
385    /// metadata service.
386    pub fn token_ttl(mut self, ttl: Duration) -> Self {
387        self.token_ttl = Some(ttl);
388        self
389    }
390
391    /// Override the connect timeout for IMDS
392    ///
393    /// This value defaults to 1 second
394    pub fn connect_timeout(mut self, timeout: Duration) -> Self {
395        self.connect_timeout = Some(timeout);
396        self
397    }
398
399    /// Override the read timeout for IMDS
400    ///
401    /// This value defaults to 1 second
402    pub fn read_timeout(mut self, timeout: Duration) -> Self {
403        self.read_timeout = Some(timeout);
404        self
405    }
406
407    /// Override the operation timeout for IMDS
408    ///
409    /// This value defaults to 1 second
410    pub fn operation_timeout(mut self, timeout: Duration) -> Self {
411        self.operation_timeout = Some(timeout);
412        self
413    }
414
415    /// Override the operation attempt timeout for IMDS
416    ///
417    /// This value defaults to 1 second
418    pub fn operation_attempt_timeout(mut self, timeout: Duration) -> Self {
419        self.operation_attempt_timeout = Some(timeout);
420        self
421    }
422
423    /// Override the retry classifier for IMDS
424    ///
425    /// This defaults to only retrying on server errors and 401s. The [ImdsResponseRetryClassifier] in this
426    /// module offers some configuration options and can be wrapped by[SharedRetryClassifier::new()] for use
427    /// here or you can create your own fully customized [SharedRetryClassifier].
428    pub fn retry_classifier(mut self, retry_classifier: SharedRetryClassifier) -> Self {
429        self.retry_classifier = Some(retry_classifier);
430        self
431    }
432
433    /* TODO(https://github.com/awslabs/aws-sdk-rust/issues/339): Support customizing the port explicitly */
434    /*
435    pub fn port(mut self, port: u32) -> Self {
436        self.port_override = Some(port);
437        self
438    }*/
439
440    /// Build an IMDSv2 Client
441    pub fn build(self) -> Client {
442        let config = self.config.unwrap_or_default();
443        let timeout_config = TimeoutConfig::builder()
444            .connect_timeout(self.connect_timeout.unwrap_or(DEFAULT_CONNECT_TIMEOUT))
445            .read_timeout(self.read_timeout.unwrap_or(DEFAULT_READ_TIMEOUT))
446            .operation_attempt_timeout(
447                self.operation_attempt_timeout
448                    .unwrap_or(DEFAULT_OPERATION_ATTEMPT_TIMEOUT),
449            )
450            .operation_timeout(self.operation_timeout.unwrap_or(DEFAULT_OPERATION_TIMEOUT))
451            .build();
452        let endpoint_source = self
453            .endpoint
454            .unwrap_or_else(|| EndpointSource::Env(config.clone()));
455        let endpoint_resolver = ImdsEndpointResolver {
456            endpoint_source: Arc::new(endpoint_source),
457            mode_override: self.mode_override,
458        };
459        let retry_config = RetryConfig::standard()
460            .with_max_attempts(self.max_attempts.unwrap_or(DEFAULT_ATTEMPTS));
461        let retry_classifier = self.retry_classifier.unwrap_or(SharedRetryClassifier::new(
462            ImdsResponseRetryClassifier::default(),
463        ));
464        let common_plugin = SharedRuntimePlugin::new(ImdsCommonRuntimePlugin::new(
465            &config,
466            endpoint_resolver,
467            retry_config,
468            retry_classifier,
469            timeout_config,
470        ));
471        let operation = Operation::builder()
472            .service_name("imds")
473            .operation_name("get")
474            .runtime_plugin(common_plugin.clone())
475            .runtime_plugin(TokenRuntimePlugin::new(
476                common_plugin,
477                self.token_ttl.unwrap_or(DEFAULT_TOKEN_TTL),
478            ))
479            .with_connection_poisoning()
480            .serializer(|path| {
481                Ok(HttpRequest::try_from(
482                    http::Request::builder()
483                        .uri(path)
484                        .body(SdkBody::empty())
485                        .expect("valid request"),
486                )
487                .unwrap())
488            })
489            .deserializer(|response| {
490                if response.status().is_success() {
491                    std::str::from_utf8(response.body().bytes().expect("non-streaming response"))
492                        .map(|data| SensitiveString::from(data.to_string()))
493                        .map_err(|_| OrchestratorError::operation(InnerImdsError::InvalidUtf8))
494                } else {
495                    Err(OrchestratorError::operation(InnerImdsError::BadStatus))
496                }
497            })
498            .build();
499        Client { operation }
500    }
501}
502
/// Environment variables consulted when resolving the IMDS endpoint.
mod env {
    /// Endpoint mode (`IPv4` / `IPv6`), used only when no endpoint URI override is found.
    pub(super) const ENDPOINT_MODE: &str = "AWS_EC2_METADATA_SERVICE_ENDPOINT_MODE";
    /// Full endpoint URI override.
    pub(super) const ENDPOINT: &str = "AWS_EC2_METADATA_SERVICE_ENDPOINT";
}
507
/// `~/.aws/config` profile keys consulted when resolving the IMDS endpoint.
mod profile_keys {
    /// Endpoint mode (`IPv4` / `IPv6`), used only when no endpoint URI override is found.
    pub(super) const ENDPOINT_MODE: &str = "ec2_metadata_service_endpoint_mode";
    /// Full endpoint URI override.
    pub(super) const ENDPOINT: &str = "ec2_metadata_service_endpoint";
}
512
513/// Endpoint Configuration Abstraction
514#[derive(Debug, Clone)]
515enum EndpointSource {
516    Explicit(Uri),
517    Env(ProviderConfig),
518}
519
520impl EndpointSource {
521    async fn endpoint(&self, mode_override: Option<EndpointMode>) -> Result<Uri, BuildError> {
522        match self {
523            EndpointSource::Explicit(uri) => {
524                if mode_override.is_some() {
525                    tracing::warn!(endpoint = ?uri, mode = ?mode_override,
526                        "Endpoint mode override was set in combination with an explicit endpoint. \
527                        The mode override will be ignored.")
528                }
529                Ok(uri.clone())
530            }
531            EndpointSource::Env(conf) => {
532                let env = conf.env();
533                // load an endpoint override from the environment
534                let profile = conf.profile().await;
535                let uri_override = if let Ok(uri) = env.get(env::ENDPOINT) {
536                    Some(Cow::Owned(uri))
537                } else {
538                    profile
539                        .and_then(|profile| profile.get(profile_keys::ENDPOINT))
540                        .map(Cow::Borrowed)
541                };
542                if let Some(uri) = uri_override {
543                    return Uri::try_from(uri.as_ref()).map_err(BuildError::invalid_endpoint_uri);
544                }
545
546                // if not, load a endpoint mode from the environment
547                let mode = if let Some(mode) = mode_override {
548                    mode
549                } else if let Ok(mode) = env.get(env::ENDPOINT_MODE) {
550                    mode.parse::<EndpointMode>()
551                        .map_err(BuildError::invalid_endpoint_mode)?
552                } else if let Some(mode) = profile.and_then(|p| p.get(profile_keys::ENDPOINT_MODE))
553                {
554                    mode.parse::<EndpointMode>()
555                        .map_err(BuildError::invalid_endpoint_mode)?
556                } else {
557                    EndpointMode::IpV4
558                };
559
560                Ok(mode.endpoint())
561            }
562        }
563    }
564}
565
566#[derive(Clone, Debug)]
567struct ImdsEndpointResolver {
568    endpoint_source: Arc<EndpointSource>,
569    mode_override: Option<EndpointMode>,
570}
571
572impl ResolveEndpoint for ImdsEndpointResolver {
573    fn resolve_endpoint<'a>(&'a self, _: &'a EndpointResolverParams) -> EndpointFuture<'a> {
574        EndpointFuture::new(async move {
575            self.endpoint_source
576                .endpoint(self.mode_override.clone())
577                .await
578                .map(|uri| Endpoint::builder().url(uri.to_string()).build())
579                .map_err(|err| err.into())
580        })
581    }
582}
583
584/// IMDS Response Retry Classifier
585///
586/// Possible status codes:
587/// - 200 (OK)
588/// - 400 (Missing or invalid parameters) **Not Retryable**
589/// - 401 (Unauthorized, expired token) **Retryable**
590/// - 403 (IMDS disabled): **Not Retryable**
591/// - 404 (Not found): **Not Retryable**
592/// - >=500 (server error): **Retryable**
593/// - Timeouts: Not retried by default, but this is configurable via [Self::with_retry_connect_timeouts()]
594#[derive(Clone, Debug, Default)]
595#[non_exhaustive]
596pub struct ImdsResponseRetryClassifier {
597    retry_connect_timeouts: bool,
598}
599
600impl ImdsResponseRetryClassifier {
601    /// Indicate whether the IMDS client should retry on connection timeouts
602    pub fn with_retry_connect_timeouts(mut self, retry_connect_timeouts: bool) -> Self {
603        self.retry_connect_timeouts = retry_connect_timeouts;
604        self
605    }
606}
607
608impl ClassifyRetry for ImdsResponseRetryClassifier {
609    fn name(&self) -> &'static str {
610        "ImdsResponseRetryClassifier"
611    }
612
613    fn classify_retry(&self, ctx: &InterceptorContext) -> RetryAction {
614        if let Some(response) = ctx.response() {
615            let status = response.status();
616            match status {
617                _ if status.is_server_error() => RetryAction::server_error(),
618                // 401 indicates that the token has expired, this is retryable
619                _ if status.as_u16() == 401 => RetryAction::server_error(),
620                // This catch-all includes successful responses that fail to parse. These should not be retried.
621                _ => RetryAction::NoActionIndicated,
622            }
623        } else if self.retry_connect_timeouts {
624            RetryAction::server_error()
625        } else {
626            // This is the default behavior.
627            // Don't retry timeouts for IMDS, or else it will take ~30 seconds for the default
628            // credentials provider chain to fail to provide credentials.
629            // Also don't retry non-responses.
630            RetryAction::NoActionIndicated
631        }
632    }
633}
634
635#[cfg(test)]
636pub(crate) mod test {
637    use crate::imds::client::{Client, EndpointMode, ImdsResponseRetryClassifier};
638    use crate::provider_config::ProviderConfig;
639    use aws_smithy_async::rt::sleep::TokioSleep;
640    use aws_smithy_async::test_util::{instant_time_and_sleep, InstantSleep};
641    use aws_smithy_http_client::test_util::{capture_request, ReplayEvent, StaticReplayClient};
642    use aws_smithy_runtime::test_util::capture_test_logs::capture_test_logs;
643    use aws_smithy_runtime_api::client::interceptors::context::{
644        Input, InterceptorContext, Output,
645    };
646    use aws_smithy_runtime_api::client::orchestrator::{
647        HttpRequest, HttpResponse, OrchestratorError,
648    };
649    use aws_smithy_runtime_api::client::result::ConnectorError;
650    use aws_smithy_runtime_api::client::retries::classifiers::{
651        ClassifyRetry, RetryAction, SharedRetryClassifier,
652    };
653    use aws_smithy_types::body::SdkBody;
654    use aws_smithy_types::error::display::DisplayErrorContext;
655    use aws_types::os_shim_internal::{Env, Fs};
656    use http::header::USER_AGENT;
657    use http::Uri;
658    use serde::Deserialize;
659    use std::collections::HashMap;
660    use std::error::Error;
661    use std::io;
662    use std::time::SystemTime;
663    use std::time::{Duration, UNIX_EPOCH};
664    use tracing_test::traced_test;
665
666    macro_rules! assert_full_error_contains {
667        ($err:expr, $contains:expr) => {
668            let err = $err;
669            let message = format!(
670                "{}",
671                aws_smithy_types::error::display::DisplayErrorContext(&err)
672            );
673            assert!(
674                message.contains($contains),
675                "Error message '{message}' didn't contain text '{}'",
676                $contains
677            );
678        };
679    }
680
681    const TOKEN_A: &str = "AQAEAFTNrA4eEGx0AQgJ1arIq_Cc-t4tWt3fB0Hd8RKhXlKc5ccvhg==";
682    const TOKEN_B: &str = "alternatetoken==";
683
684    pub(crate) fn token_request(base: &str, ttl: u32) -> HttpRequest {
685        http::Request::builder()
686            .uri(format!("{}/latest/api/token", base))
687            .header("x-aws-ec2-metadata-token-ttl-seconds", ttl)
688            .method("PUT")
689            .body(SdkBody::empty())
690            .unwrap()
691            .try_into()
692            .unwrap()
693    }
694
695    pub(crate) fn token_response(ttl: u32, token: &'static str) -> HttpResponse {
696        HttpResponse::try_from(
697            http::Response::builder()
698                .status(200)
699                .header("X-aws-ec2-metadata-token-ttl-seconds", ttl)
700                .body(SdkBody::from(token))
701                .unwrap(),
702        )
703        .unwrap()
704    }
705
706    pub(crate) fn imds_request(path: &'static str, token: &str) -> HttpRequest {
707        http::Request::builder()
708            .uri(Uri::from_static(path))
709            .method("GET")
710            .header("x-aws-ec2-metadata-token", token)
711            .body(SdkBody::empty())
712            .unwrap()
713            .try_into()
714            .unwrap()
715    }
716
717    pub(crate) fn imds_response(body: &'static str) -> HttpResponse {
718        HttpResponse::try_from(
719            http::Response::builder()
720                .status(200)
721                .body(SdkBody::from(body))
722                .unwrap(),
723        )
724        .unwrap()
725    }
726
727    pub(crate) fn make_imds_client(http_client: &StaticReplayClient) -> super::Client {
728        tokio::time::pause();
729        super::Client::builder()
730            .configure(
731                &ProviderConfig::no_configuration()
732                    .with_sleep_impl(InstantSleep::unlogged())
733                    .with_http_client(http_client.clone()),
734            )
735            .build()
736    }
737
738    fn mock_imds_client(events: Vec<ReplayEvent>) -> (Client, StaticReplayClient) {
739        let http_client = StaticReplayClient::new(events);
740        let client = make_imds_client(&http_client);
741        (client, http_client)
742    }
743
744    #[tokio::test]
745    async fn client_caches_token() {
746        let (client, http_client) = mock_imds_client(vec![
747            ReplayEvent::new(
748                token_request("http://169.254.169.254", 21600),
749                token_response(21600, TOKEN_A),
750            ),
751            ReplayEvent::new(
752                imds_request("http://169.254.169.254/latest/metadata", TOKEN_A),
753                imds_response(r#"test-imds-output"#),
754            ),
755            ReplayEvent::new(
756                imds_request("http://169.254.169.254/latest/metadata2", TOKEN_A),
757                imds_response("output2"),
758            ),
759        ]);
760        // load once
761        let metadata = client.get("/latest/metadata").await.expect("failed");
762        assert_eq!("test-imds-output", metadata.as_ref());
763        // load again: the cached token should be used
764        let metadata = client.get("/latest/metadata2").await.expect("failed");
765        assert_eq!("output2", metadata.as_ref());
766        http_client.assert_requests_match(&[]);
767    }
768
769    #[tokio::test]
770    async fn token_can_expire() {
771        let (_, http_client) = mock_imds_client(vec![
772            ReplayEvent::new(
773                token_request("http://[fd00:ec2::254]", 600),
774                token_response(600, TOKEN_A),
775            ),
776            ReplayEvent::new(
777                imds_request("http://[fd00:ec2::254]/latest/metadata", TOKEN_A),
778                imds_response(r#"test-imds-output1"#),
779            ),
780            ReplayEvent::new(
781                token_request("http://[fd00:ec2::254]", 600),
782                token_response(600, TOKEN_B),
783            ),
784            ReplayEvent::new(
785                imds_request("http://[fd00:ec2::254]/latest/metadata", TOKEN_B),
786                imds_response(r#"test-imds-output2"#),
787            ),
788        ]);
789        let (time_source, sleep) = instant_time_and_sleep(UNIX_EPOCH);
790        let client = super::Client::builder()
791            .configure(
792                &ProviderConfig::no_configuration()
793                    .with_http_client(http_client.clone())
794                    .with_time_source(time_source.clone())
795                    .with_sleep_impl(sleep),
796            )
797            .endpoint_mode(EndpointMode::IpV6)
798            .token_ttl(Duration::from_secs(600))
799            .build();
800
801        let resp1 = client.get("/latest/metadata").await.expect("success");
802        // now the cached credential has expired
803        time_source.advance(Duration::from_secs(600));
804        let resp2 = client.get("/latest/metadata").await.expect("success");
805        http_client.assert_requests_match(&[]);
806        assert_eq!("test-imds-output1", resp1.as_ref());
807        assert_eq!("test-imds-output2", resp2.as_ref());
808    }
809
810    /// Tokens are refreshed up to 120 seconds early to avoid using an expired token.
811    #[tokio::test]
812    async fn token_refresh_buffer() {
813        let _logs = capture_test_logs();
814        let (_, http_client) = mock_imds_client(vec![
815            ReplayEvent::new(
816                token_request("http://[fd00:ec2::254]", 600),
817                token_response(600, TOKEN_A),
818            ),
819            // t = 0
820            ReplayEvent::new(
821                imds_request("http://[fd00:ec2::254]/latest/metadata", TOKEN_A),
822                imds_response(r#"test-imds-output1"#),
823            ),
824            // t = 400 (no refresh)
825            ReplayEvent::new(
826                imds_request("http://[fd00:ec2::254]/latest/metadata", TOKEN_A),
827                imds_response(r#"test-imds-output2"#),
828            ),
829            // t = 550 (within buffer)
830            ReplayEvent::new(
831                token_request("http://[fd00:ec2::254]", 600),
832                token_response(600, TOKEN_B),
833            ),
834            ReplayEvent::new(
835                imds_request("http://[fd00:ec2::254]/latest/metadata", TOKEN_B),
836                imds_response(r#"test-imds-output3"#),
837            ),
838        ]);
839        let (time_source, sleep) = instant_time_and_sleep(UNIX_EPOCH);
840        let client = super::Client::builder()
841            .configure(
842                &ProviderConfig::no_configuration()
843                    .with_sleep_impl(sleep)
844                    .with_http_client(http_client.clone())
845                    .with_time_source(time_source.clone()),
846            )
847            .endpoint_mode(EndpointMode::IpV6)
848            .token_ttl(Duration::from_secs(600))
849            .build();
850
851        tracing::info!("resp1 -----------------------------------------------------------");
852        let resp1 = client.get("/latest/metadata").await.expect("success");
853        // now the cached credential has expired
854        time_source.advance(Duration::from_secs(400));
855        tracing::info!("resp2 -----------------------------------------------------------");
856        let resp2 = client.get("/latest/metadata").await.expect("success");
857        time_source.advance(Duration::from_secs(150));
858        tracing::info!("resp3 -----------------------------------------------------------");
859        let resp3 = client.get("/latest/metadata").await.expect("success");
860        http_client.assert_requests_match(&[]);
861        assert_eq!("test-imds-output1", resp1.as_ref());
862        assert_eq!("test-imds-output2", resp2.as_ref());
863        assert_eq!("test-imds-output3", resp3.as_ref());
864    }
865
866    /// 500 error during the GET should be retried
867    #[tokio::test]
868    #[traced_test]
869    async fn retry_500() {
870        let (client, http_client) = mock_imds_client(vec![
871            ReplayEvent::new(
872                token_request("http://169.254.169.254", 21600),
873                token_response(21600, TOKEN_A),
874            ),
875            ReplayEvent::new(
876                imds_request("http://169.254.169.254/latest/metadata", TOKEN_A),
877                http::Response::builder()
878                    .status(500)
879                    .body(SdkBody::empty())
880                    .unwrap(),
881            ),
882            ReplayEvent::new(
883                imds_request("http://169.254.169.254/latest/metadata", TOKEN_A),
884                imds_response("ok"),
885            ),
886        ]);
887        assert_eq!(
888            "ok",
889            client
890                .get("/latest/metadata")
891                .await
892                .expect("success")
893                .as_ref()
894        );
895        http_client.assert_requests_match(&[]);
896
897        // all requests should have a user agent header
898        for request in http_client.actual_requests() {
899            assert!(request.headers().get(USER_AGENT).is_some());
900        }
901    }
902
903    /// 500 error during token acquisition should be retried
904    #[tokio::test]
905    #[traced_test]
906    async fn retry_token_failure() {
907        let (client, http_client) = mock_imds_client(vec![
908            ReplayEvent::new(
909                token_request("http://169.254.169.254", 21600),
910                http::Response::builder()
911                    .status(500)
912                    .body(SdkBody::empty())
913                    .unwrap(),
914            ),
915            ReplayEvent::new(
916                token_request("http://169.254.169.254", 21600),
917                token_response(21600, TOKEN_A),
918            ),
919            ReplayEvent::new(
920                imds_request("http://169.254.169.254/latest/metadata", TOKEN_A),
921                imds_response("ok"),
922            ),
923        ]);
924        assert_eq!(
925            "ok",
926            client
927                .get("/latest/metadata")
928                .await
929                .expect("success")
930                .as_ref()
931        );
932        http_client.assert_requests_match(&[]);
933    }
934
935    /// 401 error during metadata retrieval must be retried
936    #[tokio::test]
937    #[traced_test]
938    async fn retry_metadata_401() {
939        let (client, http_client) = mock_imds_client(vec![
940            ReplayEvent::new(
941                token_request("http://169.254.169.254", 21600),
942                token_response(0, TOKEN_A),
943            ),
944            ReplayEvent::new(
945                imds_request("http://169.254.169.254/latest/metadata", TOKEN_A),
946                http::Response::builder()
947                    .status(401)
948                    .body(SdkBody::empty())
949                    .unwrap(),
950            ),
951            ReplayEvent::new(
952                token_request("http://169.254.169.254", 21600),
953                token_response(21600, TOKEN_B),
954            ),
955            ReplayEvent::new(
956                imds_request("http://169.254.169.254/latest/metadata", TOKEN_B),
957                imds_response("ok"),
958            ),
959        ]);
960        assert_eq!(
961            "ok",
962            client
963                .get("/latest/metadata")
964                .await
965                .expect("success")
966                .as_ref()
967        );
968        http_client.assert_requests_match(&[]);
969    }
970
971    /// 403 responses from IMDS during token acquisition MUST NOT be retried
972    #[tokio::test]
973    #[traced_test]
974    async fn no_403_retry() {
975        let (client, http_client) = mock_imds_client(vec![ReplayEvent::new(
976            token_request("http://169.254.169.254", 21600),
977            http::Response::builder()
978                .status(403)
979                .body(SdkBody::empty())
980                .unwrap(),
981        )]);
982        let err = client.get("/latest/metadata").await.expect_err("no token");
983        assert_full_error_contains!(err, "forbidden");
984        http_client.assert_requests_match(&[]);
985    }
986
987    /// The classifier should return `None` when classifying a successful response.
988    #[test]
989    fn successful_response_properly_classified() {
990        let mut ctx = InterceptorContext::new(Input::doesnt_matter());
991        ctx.set_output_or_error(Ok(Output::doesnt_matter()));
992        ctx.set_response(imds_response("").map(|_| SdkBody::empty()));
993        let classifier = ImdsResponseRetryClassifier::default();
994        assert_eq!(
995            RetryAction::NoActionIndicated,
996            classifier.classify_retry(&ctx)
997        );
998
999        // Emulate a failure to parse the response body (using an io error since it's easy to construct in a test)
1000        let mut ctx = InterceptorContext::new(Input::doesnt_matter());
1001        ctx.set_output_or_error(Err(OrchestratorError::connector(ConnectorError::io(
1002            io::Error::new(io::ErrorKind::BrokenPipe, "fail to parse").into(),
1003        ))));
1004        assert_eq!(
1005            RetryAction::NoActionIndicated,
1006            classifier.classify_retry(&ctx)
1007        );
1008    }
1009
1010    /// User provided retry classifier works
1011    #[tokio::test]
1012    async fn user_provided_retry_classifier() {
1013        #[derive(Clone, Debug)]
1014        struct UserProvidedRetryClassifier;
1015
1016        impl ClassifyRetry for UserProvidedRetryClassifier {
1017            fn name(&self) -> &'static str {
1018                "UserProvidedRetryClassifier"
1019            }
1020
1021            // Don't retry anything
1022            fn classify_retry(&self, _ctx: &InterceptorContext) -> RetryAction {
1023                RetryAction::RetryForbidden
1024            }
1025        }
1026
1027        let events = vec![
1028            ReplayEvent::new(
1029                token_request("http://169.254.169.254", 21600),
1030                token_response(0, TOKEN_A),
1031            ),
1032            ReplayEvent::new(
1033                imds_request("http://169.254.169.254/latest/metadata", TOKEN_A),
1034                http::Response::builder()
1035                    .status(401)
1036                    .body(SdkBody::empty())
1037                    .unwrap(),
1038            ),
1039            ReplayEvent::new(
1040                token_request("http://169.254.169.254", 21600),
1041                token_response(21600, TOKEN_B),
1042            ),
1043            ReplayEvent::new(
1044                imds_request("http://169.254.169.254/latest/metadata", TOKEN_B),
1045                imds_response("ok"),
1046            ),
1047        ];
1048        let http_client = StaticReplayClient::new(events);
1049
1050        let imds_client = super::Client::builder()
1051            .configure(
1052                &ProviderConfig::no_configuration()
1053                    .with_sleep_impl(InstantSleep::unlogged())
1054                    .with_http_client(http_client.clone()),
1055            )
1056            .retry_classifier(SharedRetryClassifier::new(UserProvidedRetryClassifier))
1057            .build();
1058
1059        let res = imds_client
1060            .get("/latest/metadata")
1061            .await
1062            .expect_err("Client should error");
1063
1064        // Assert that the operation errored on the initial 401 and did not retry and get
1065        // the 200 (since the user provided retry classifier never retries)
1066        assert_full_error_contains!(res, "401");
1067    }
1068
1069    // since tokens are sent as headers, the tokens need to be valid header values
1070    #[tokio::test]
1071    async fn invalid_token() {
1072        let (client, http_client) = mock_imds_client(vec![ReplayEvent::new(
1073            token_request("http://169.254.169.254", 21600),
1074            token_response(21600, "invalid\nheader\nvalue\0"),
1075        )]);
1076        let err = client.get("/latest/metadata").await.expect_err("no token");
1077        assert_full_error_contains!(err, "invalid token");
1078        http_client.assert_requests_match(&[]);
1079    }
1080
1081    #[tokio::test]
1082    async fn non_utf8_response() {
1083        let (client, http_client) = mock_imds_client(vec![
1084            ReplayEvent::new(
1085                token_request("http://169.254.169.254", 21600),
1086                token_response(21600, TOKEN_A).map(SdkBody::from),
1087            ),
1088            ReplayEvent::new(
1089                imds_request("http://169.254.169.254/latest/metadata", TOKEN_A),
1090                http::Response::builder()
1091                    .status(200)
1092                    .body(SdkBody::from(vec![0xA0, 0xA1]))
1093                    .unwrap(),
1094            ),
1095        ]);
1096        let err = client.get("/latest/metadata").await.expect_err("no token");
1097        assert_full_error_contains!(err, "invalid UTF-8");
1098        http_client.assert_requests_match(&[]);
1099    }
1100
1101    // TODO(https://github.com/awslabs/aws-sdk-rust/issues/1117) This test is ignored on Windows because it uses Unix-style paths
1102    #[cfg_attr(windows, ignore)]
1103    /// Verify that the end-to-end real client has a 1-second connect timeout
1104    #[tokio::test]
1105    #[cfg(feature = "default-https-client")]
1106    async fn one_second_connect_timeout() {
1107        use crate::imds::client::ImdsError;
1108        let client = Client::builder()
1109            // 240.* can never be resolved
1110            .endpoint("http://240.0.0.0")
1111            .expect("valid uri")
1112            .build();
1113        let now = SystemTime::now();
1114        let resp = client
1115            .get("/latest/metadata")
1116            .await
1117            .expect_err("240.0.0.0 will never resolve");
1118        match resp {
1119            err @ ImdsError::FailedToLoadToken(_)
1120                if format!("{}", DisplayErrorContext(&err)).contains("timeout") => {} // ok,
1121            other => panic!(
1122                "wrong error, expected construction failure with TimedOutError inside: {}",
1123                DisplayErrorContext(&other)
1124            ),
1125        }
1126        let time_elapsed = now.elapsed().unwrap();
1127        assert!(
1128            time_elapsed > Duration::from_secs(1),
1129            "time_elapsed should be greater than 1s but was {:?}",
1130            time_elapsed
1131        );
1132        assert!(
1133            time_elapsed < Duration::from_secs(2),
1134            "time_elapsed should be less than 2s but was {:?}",
1135            time_elapsed
1136        );
1137    }
1138
1139    /// Retry classifier properly retries timeouts when configured to (meaning it takes ~30s to fail)
1140    #[tokio::test]
1141    async fn retry_connect_timeouts() {
1142        let http_client = StaticReplayClient::new(vec![]);
1143        let imds_client = super::Client::builder()
1144            .retry_classifier(SharedRetryClassifier::new(
1145                ImdsResponseRetryClassifier::default().with_retry_connect_timeouts(true),
1146            ))
1147            .configure(&ProviderConfig::no_configuration().with_http_client(http_client.clone()))
1148            .operation_timeout(Duration::from_secs(1))
1149            .endpoint("http://240.0.0.0")
1150            .expect("valid uri")
1151            .build();
1152
1153        let now = SystemTime::now();
1154        let _res = imds_client
1155            .get("/latest/metadata")
1156            .await
1157            .expect_err("240.0.0.0 will never resolve");
1158        let time_elapsed: Duration = now.elapsed().unwrap();
1159
1160        assert!(
1161            time_elapsed > Duration::from_secs(1),
1162            "time_elapsed should be greater than 1s but was {:?}",
1163            time_elapsed
1164        );
1165
1166        assert!(
1167            time_elapsed < Duration::from_secs(2),
1168            "time_elapsed should be less than 2s but was {:?}",
1169            time_elapsed
1170        );
1171    }
1172
    /// One endpoint-configuration case, deserialized from
    /// `test-data/imds-config/imds-endpoint-tests.json`.
    #[derive(Debug, Deserialize)]
    struct ImdsConfigTest {
        /// Environment variables made visible to the client under test (via `Env::from`).
        env: HashMap<String, String>,
        /// Fake file-system contents for the client (via `Fs::from_map`); presumably path -> file body.
        fs: HashMap<String, String>,
        /// Optional explicit endpoint passed to the client builder.
        endpoint_override: Option<String>,
        /// Optional endpoint mode string, parsed into an `EndpointMode` for the builder.
        mode_override: Option<String>,
        /// `Ok(uri)`: the URI the request is expected to be sent to.
        /// `Err(msg)`: a substring expected in the resulting error message.
        result: Result<String, String>,
        /// Description of the case, included in the failure message of error-case assertions.
        docs: String,
    }
1182
1183    #[tokio::test]
1184    async fn endpoint_config_tests() -> Result<(), Box<dyn Error>> {
1185        let _logs = capture_test_logs();
1186
1187        let test_cases = std::fs::read_to_string("test-data/imds-config/imds-endpoint-tests.json")?;
1188        #[derive(Deserialize)]
1189        struct TestCases {
1190            tests: Vec<ImdsConfigTest>,
1191        }
1192
1193        let test_cases: TestCases = serde_json::from_str(&test_cases)?;
1194        let test_cases = test_cases.tests;
1195        for test in test_cases {
1196            check(test).await;
1197        }
1198        Ok(())
1199    }
1200
1201    async fn check(test_case: ImdsConfigTest) {
1202        let (http_client, watcher) = capture_request(None);
1203        let provider_config = ProviderConfig::no_configuration()
1204            .with_sleep_impl(TokioSleep::new())
1205            .with_env(Env::from(test_case.env))
1206            .with_fs(Fs::from_map(test_case.fs))
1207            .with_http_client(http_client);
1208        let mut imds_client = Client::builder().configure(&provider_config);
1209        if let Some(endpoint_override) = test_case.endpoint_override {
1210            imds_client = imds_client
1211                .endpoint(endpoint_override)
1212                .expect("invalid URI");
1213        }
1214
1215        if let Some(mode_override) = test_case.mode_override {
1216            imds_client = imds_client.endpoint_mode(mode_override.parse().unwrap());
1217        }
1218
1219        let imds_client = imds_client.build();
1220        match &test_case.result {
1221            Ok(uri) => {
1222                // this request will fail, we just want to capture the endpoint configuration
1223                let _ = imds_client.get("/hello").await;
1224                assert_eq!(&watcher.expect_request().uri().to_string(), uri);
1225            }
1226            Err(expected) => {
1227                let err = imds_client.get("/hello").await.expect_err("it should fail");
1228                let message = format!("{}", DisplayErrorContext(&err));
1229                assert!(
1230                    message.contains(expected),
1231                    "{}\nexpected error: {expected}\nactual error: {message}",
1232                    test_case.docs
1233                );
1234            }
1235        };
1236    }
1237}