1use crate::imds::client::error::{BuildError, ImdsError, InnerImdsError, InvalidEndpointMode};
11use crate::imds::client::token::TokenRuntimePlugin;
12use crate::provider_config::ProviderConfig;
13use crate::PKG_VERSION;
14use aws_runtime::user_agent::{ApiMetadata, AwsUserAgent, UserAgentInterceptor};
15use aws_smithy_runtime::client::orchestrator::operation::Operation;
16use aws_smithy_runtime::client::retries::strategy::StandardRetryStrategy;
17use aws_smithy_runtime_api::box_error::BoxError;
18use aws_smithy_runtime_api::client::auth::AuthSchemeOptionResolverParams;
19use aws_smithy_runtime_api::client::endpoint::{
20 EndpointFuture, EndpointResolverParams, ResolveEndpoint,
21};
22use aws_smithy_runtime_api::client::interceptors::context::InterceptorContext;
23use aws_smithy_runtime_api::client::orchestrator::{
24 HttpRequest, OrchestratorError, SensitiveOutput,
25};
26use aws_smithy_runtime_api::client::result::ConnectorError;
27use aws_smithy_runtime_api::client::result::SdkError;
28use aws_smithy_runtime_api::client::retries::classifiers::{
29 ClassifyRetry, RetryAction, SharedRetryClassifier,
30};
31use aws_smithy_runtime_api::client::runtime_components::RuntimeComponentsBuilder;
32use aws_smithy_runtime_api::client::runtime_plugin::{RuntimePlugin, SharedRuntimePlugin};
33use aws_smithy_types::body::SdkBody;
34use aws_smithy_types::config_bag::{FrozenLayer, Layer};
35use aws_smithy_types::endpoint::Endpoint;
36use aws_smithy_types::retry::RetryConfig;
37use aws_smithy_types::timeout::TimeoutConfig;
38use aws_types::os_shim_internal::Env;
39use http::Uri;
40use std::borrow::Cow;
41use std::error::Error as _;
42use std::fmt;
43use std::str::FromStr;
44use std::sync::Arc;
45use std::time::Duration;
46
47pub mod error;
48mod token;
49
/// Token lifetime used when the builder is given no `token_ttl` (six hours).
const DEFAULT_TOKEN_TTL: Duration = Duration::from_secs(6 * 60 * 60);
/// Maximum number of attempts used when the builder is given no `max_attempts`.
const DEFAULT_ATTEMPTS: u32 = 4;
/// Connect timeout used when the builder is given no `connect_timeout`.
const DEFAULT_CONNECT_TIMEOUT: Duration = Duration::from_secs(1);
/// Read timeout used when the builder is given no `read_timeout`.
const DEFAULT_READ_TIMEOUT: Duration = Duration::from_secs(1);
/// Whole-operation timeout used when the builder is given no `operation_timeout`.
const DEFAULT_OPERATION_TIMEOUT: Duration = Duration::from_secs(30);
/// Per-attempt timeout used when the builder is given no `operation_attempt_timeout`.
const DEFAULT_OPERATION_ATTEMPT_TIMEOUT: Duration = Duration::from_secs(10);
57
58fn user_agent() -> AwsUserAgent {
59 AwsUserAgent::new_from_environment(Env::real(), ApiMetadata::new("imds", PKG_VERSION))
60}
61
62#[derive(Clone, Debug)]
131pub struct Client {
132 operation: Operation<String, SensitiveString, InnerImdsError>,
133}
134
135impl Client {
136 pub fn builder() -> Builder {
138 Builder::default()
139 }
140
141 pub async fn get(&self, path: impl Into<String>) -> Result<SensitiveString, ImdsError> {
162 self.operation
163 .invoke(path.into())
164 .await
165 .map_err(|err| match err {
166 SdkError::ConstructionFailure(_) if err.source().is_some() => {
167 match err.into_source().map(|e| e.downcast::<ImdsError>()) {
168 Ok(Ok(token_failure)) => *token_failure,
169 Ok(Err(err)) => ImdsError::unexpected(err),
170 Err(err) => ImdsError::unexpected(err),
171 }
172 }
173 SdkError::ConstructionFailure(_) => ImdsError::unexpected(err),
174 SdkError::ServiceError(context) => match context.err() {
175 InnerImdsError::InvalidUtf8 => {
176 ImdsError::unexpected("IMDS returned invalid UTF-8")
177 }
178 InnerImdsError::BadStatus => ImdsError::error_response(context.into_raw()),
179 },
180 err @ SdkError::DispatchFailure(_) => match err.into_source() {
184 Ok(source) => match source.downcast::<ConnectorError>() {
185 Ok(source) => match source.into_source().downcast::<ImdsError>() {
186 Ok(source) => *source,
187 Err(err) => ImdsError::unexpected(err),
188 },
189 Err(err) => ImdsError::unexpected(err),
190 },
191 Err(err) => ImdsError::unexpected(err),
192 },
193 SdkError::TimeoutError(_) | SdkError::ResponseError(_) => ImdsError::io_error(err),
194 _ => ImdsError::unexpected(err),
195 })
196 }
197}
198
/// String newtype whose `Debug` output prints a fixed placeholder instead of
/// the wrapped value.
///
/// The contents remain reachable through `AsRef<str>` and `Into<String>`.
#[derive(Clone)]
pub struct SensitiveString(String);

impl fmt::Debug for SensitiveString {
    /// Formats as a one-field tuple struct holding the redaction placeholder.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut tuple = f.debug_tuple("SensitiveString");
        tuple.field(&"** redacted **");
        tuple.finish()
    }
}

impl AsRef<str> for SensitiveString {
    /// Borrows the wrapped value.
    fn as_ref(&self) -> &str {
        self.0.as_str()
    }
}

impl From<String> for SensitiveString {
    /// Wraps `value` without copying it.
    fn from(value: String) -> Self {
        SensitiveString(value)
    }
}

impl From<SensitiveString> for String {
    /// Unwraps the inner `String`.
    fn from(value: SensitiveString) -> Self {
        let SensitiveString(inner) = value;
        inner
    }
}
228
229#[derive(Debug)]
233struct ImdsCommonRuntimePlugin {
234 config: FrozenLayer,
235 components: RuntimeComponentsBuilder,
236}
237
238impl ImdsCommonRuntimePlugin {
239 fn new(
240 config: &ProviderConfig,
241 endpoint_resolver: ImdsEndpointResolver,
242 retry_config: RetryConfig,
243 retry_classifier: SharedRetryClassifier,
244 timeout_config: TimeoutConfig,
245 ) -> Self {
246 let mut layer = Layer::new("ImdsCommonRuntimePlugin");
247 layer.store_put(AuthSchemeOptionResolverParams::new(()));
248 layer.store_put(EndpointResolverParams::new(()));
249 layer.store_put(SensitiveOutput);
250 layer.store_put(retry_config);
251 layer.store_put(timeout_config);
252 layer.store_put(user_agent());
253
254 Self {
255 config: layer.freeze(),
256 components: RuntimeComponentsBuilder::new("ImdsCommonRuntimePlugin")
257 .with_http_client(config.http_client())
258 .with_endpoint_resolver(Some(endpoint_resolver))
259 .with_interceptor(UserAgentInterceptor::new())
260 .with_retry_classifier(retry_classifier)
261 .with_retry_strategy(Some(StandardRetryStrategy::new()))
262 .with_time_source(Some(config.time_source()))
263 .with_sleep_impl(config.sleep_impl()),
264 }
265 }
266}
267
268impl RuntimePlugin for ImdsCommonRuntimePlugin {
269 fn config(&self) -> Option<FrozenLayer> {
270 Some(self.config.clone())
271 }
272
273 fn runtime_components(
274 &self,
275 _current_components: &RuntimeComponentsBuilder,
276 ) -> Cow<'_, RuntimeComponentsBuilder> {
277 Cow::Borrowed(&self.components)
278 }
279}
280
281#[derive(Debug, Clone)]
287#[non_exhaustive]
288pub enum EndpointMode {
289 IpV4,
293 IpV6,
295}
296
297impl FromStr for EndpointMode {
298 type Err = InvalidEndpointMode;
299
300 fn from_str(value: &str) -> Result<Self, Self::Err> {
301 match value {
302 _ if value.eq_ignore_ascii_case("ipv4") => Ok(EndpointMode::IpV4),
303 _ if value.eq_ignore_ascii_case("ipv6") => Ok(EndpointMode::IpV6),
304 other => Err(InvalidEndpointMode::new(other.to_owned())),
305 }
306 }
307}
308
309impl EndpointMode {
310 fn endpoint(&self) -> Uri {
312 match self {
313 EndpointMode::IpV4 => Uri::from_static("http://169.254.169.254"),
314 EndpointMode::IpV6 => Uri::from_static("http://[fd00:ec2::254]"),
315 }
316 }
317}
318
319#[derive(Default, Debug, Clone)]
321pub struct Builder {
322 max_attempts: Option<u32>,
323 endpoint: Option<EndpointSource>,
324 mode_override: Option<EndpointMode>,
325 token_ttl: Option<Duration>,
326 connect_timeout: Option<Duration>,
327 read_timeout: Option<Duration>,
328 operation_timeout: Option<Duration>,
329 operation_attempt_timeout: Option<Duration>,
330 config: Option<ProviderConfig>,
331 retry_classifier: Option<SharedRetryClassifier>,
332}
333
334impl Builder {
335 pub fn max_attempts(mut self, max_attempts: u32) -> Self {
339 self.max_attempts = Some(max_attempts);
340 self
341 }
342
343 pub fn configure(mut self, provider_config: &ProviderConfig) -> Self {
357 self.config = Some(provider_config.clone());
358 self
359 }
360
361 pub fn endpoint(mut self, endpoint: impl AsRef<str>) -> Result<Self, BoxError> {
367 let uri: Uri = endpoint.as_ref().parse()?;
368 self.endpoint = Some(EndpointSource::Explicit(uri));
369 Ok(self)
370 }
371
372 pub fn endpoint_mode(mut self, mode: EndpointMode) -> Self {
377 self.mode_override = Some(mode);
378 self
379 }
380
381 pub fn token_ttl(mut self, ttl: Duration) -> Self {
387 self.token_ttl = Some(ttl);
388 self
389 }
390
391 pub fn connect_timeout(mut self, timeout: Duration) -> Self {
395 self.connect_timeout = Some(timeout);
396 self
397 }
398
399 pub fn read_timeout(mut self, timeout: Duration) -> Self {
403 self.read_timeout = Some(timeout);
404 self
405 }
406
407 pub fn operation_timeout(mut self, timeout: Duration) -> Self {
411 self.operation_timeout = Some(timeout);
412 self
413 }
414
415 pub fn operation_attempt_timeout(mut self, timeout: Duration) -> Self {
419 self.operation_attempt_timeout = Some(timeout);
420 self
421 }
422
423 pub fn retry_classifier(mut self, retry_classifier: SharedRetryClassifier) -> Self {
429 self.retry_classifier = Some(retry_classifier);
430 self
431 }
432
433 pub fn build(self) -> Client {
442 let config = self.config.unwrap_or_default();
443 let timeout_config = TimeoutConfig::builder()
444 .connect_timeout(self.connect_timeout.unwrap_or(DEFAULT_CONNECT_TIMEOUT))
445 .read_timeout(self.read_timeout.unwrap_or(DEFAULT_READ_TIMEOUT))
446 .operation_attempt_timeout(
447 self.operation_attempt_timeout
448 .unwrap_or(DEFAULT_OPERATION_ATTEMPT_TIMEOUT),
449 )
450 .operation_timeout(self.operation_timeout.unwrap_or(DEFAULT_OPERATION_TIMEOUT))
451 .build();
452 let endpoint_source = self
453 .endpoint
454 .unwrap_or_else(|| EndpointSource::Env(config.clone()));
455 let endpoint_resolver = ImdsEndpointResolver {
456 endpoint_source: Arc::new(endpoint_source),
457 mode_override: self.mode_override,
458 };
459 let retry_config = RetryConfig::standard()
460 .with_max_attempts(self.max_attempts.unwrap_or(DEFAULT_ATTEMPTS));
461 let retry_classifier = self.retry_classifier.unwrap_or(SharedRetryClassifier::new(
462 ImdsResponseRetryClassifier::default(),
463 ));
464 let common_plugin = SharedRuntimePlugin::new(ImdsCommonRuntimePlugin::new(
465 &config,
466 endpoint_resolver,
467 retry_config,
468 retry_classifier,
469 timeout_config,
470 ));
471 let operation = Operation::builder()
472 .service_name("imds")
473 .operation_name("get")
474 .runtime_plugin(common_plugin.clone())
475 .runtime_plugin(TokenRuntimePlugin::new(
476 common_plugin,
477 self.token_ttl.unwrap_or(DEFAULT_TOKEN_TTL),
478 ))
479 .with_connection_poisoning()
480 .serializer(|path| {
481 Ok(HttpRequest::try_from(
482 http::Request::builder()
483 .uri(path)
484 .body(SdkBody::empty())
485 .expect("valid request"),
486 )
487 .unwrap())
488 })
489 .deserializer(|response| {
490 if response.status().is_success() {
491 std::str::from_utf8(response.body().bytes().expect("non-streaming response"))
492 .map(|data| SensitiveString::from(data.to_string()))
493 .map_err(|_| OrchestratorError::operation(InnerImdsError::InvalidUtf8))
494 } else {
495 Err(OrchestratorError::operation(InnerImdsError::BadStatus))
496 }
497 })
498 .build();
499 Client { operation }
500 }
501}
502
/// Names of the environment variables read when resolving the IMDS endpoint.
mod env {
    /// Variable naming the endpoint mode (parsed as an `EndpointMode`).
    pub(super) const ENDPOINT_MODE: &str = "AWS_EC2_METADATA_SERVICE_ENDPOINT_MODE";
    /// Variable holding a full endpoint URI.
    pub(super) const ENDPOINT: &str = "AWS_EC2_METADATA_SERVICE_ENDPOINT";
}
507
/// Names of the profile keys read when resolving the IMDS endpoint.
mod profile_keys {
    /// Key naming the endpoint mode (parsed as an `EndpointMode`).
    pub(super) const ENDPOINT_MODE: &str = "ec2_metadata_service_endpoint_mode";
    /// Key holding a full endpoint URI.
    pub(super) const ENDPOINT: &str = "ec2_metadata_service_endpoint";
}
512
513#[derive(Debug, Clone)]
515enum EndpointSource {
516 Explicit(Uri),
517 Env(ProviderConfig),
518}
519
520impl EndpointSource {
521 async fn endpoint(&self, mode_override: Option<EndpointMode>) -> Result<Uri, BuildError> {
522 match self {
523 EndpointSource::Explicit(uri) => {
524 if mode_override.is_some() {
525 tracing::warn!(endpoint = ?uri, mode = ?mode_override,
526 "Endpoint mode override was set in combination with an explicit endpoint. \
527 The mode override will be ignored.")
528 }
529 Ok(uri.clone())
530 }
531 EndpointSource::Env(conf) => {
532 let env = conf.env();
533 let profile = conf.profile().await;
535 let uri_override = if let Ok(uri) = env.get(env::ENDPOINT) {
536 Some(Cow::Owned(uri))
537 } else {
538 profile
539 .and_then(|profile| profile.get(profile_keys::ENDPOINT))
540 .map(Cow::Borrowed)
541 };
542 if let Some(uri) = uri_override {
543 return Uri::try_from(uri.as_ref()).map_err(BuildError::invalid_endpoint_uri);
544 }
545
546 let mode = if let Some(mode) = mode_override {
548 mode
549 } else if let Ok(mode) = env.get(env::ENDPOINT_MODE) {
550 mode.parse::<EndpointMode>()
551 .map_err(BuildError::invalid_endpoint_mode)?
552 } else if let Some(mode) = profile.and_then(|p| p.get(profile_keys::ENDPOINT_MODE))
553 {
554 mode.parse::<EndpointMode>()
555 .map_err(BuildError::invalid_endpoint_mode)?
556 } else {
557 EndpointMode::IpV4
558 };
559
560 Ok(mode.endpoint())
561 }
562 }
563 }
564}
565
566#[derive(Clone, Debug)]
567struct ImdsEndpointResolver {
568 endpoint_source: Arc<EndpointSource>,
569 mode_override: Option<EndpointMode>,
570}
571
572impl ResolveEndpoint for ImdsEndpointResolver {
573 fn resolve_endpoint<'a>(&'a self, _: &'a EndpointResolverParams) -> EndpointFuture<'a> {
574 EndpointFuture::new(async move {
575 self.endpoint_source
576 .endpoint(self.mode_override.clone())
577 .await
578 .map(|uri| Endpoint::builder().url(uri.to_string()).build())
579 .map_err(|err| err.into())
580 })
581 }
582}
583
584#[derive(Clone, Debug, Default)]
595#[non_exhaustive]
596pub struct ImdsResponseRetryClassifier {
597 retry_connect_timeouts: bool,
598}
599
600impl ImdsResponseRetryClassifier {
601 pub fn with_retry_connect_timeouts(mut self, retry_connect_timeouts: bool) -> Self {
603 self.retry_connect_timeouts = retry_connect_timeouts;
604 self
605 }
606}
607
608impl ClassifyRetry for ImdsResponseRetryClassifier {
609 fn name(&self) -> &'static str {
610 "ImdsResponseRetryClassifier"
611 }
612
613 fn classify_retry(&self, ctx: &InterceptorContext) -> RetryAction {
614 if let Some(response) = ctx.response() {
615 let status = response.status();
616 match status {
617 _ if status.is_server_error() => RetryAction::server_error(),
618 _ if status.as_u16() == 401 => RetryAction::server_error(),
620 _ => RetryAction::NoActionIndicated,
622 }
623 } else if self.retry_connect_timeouts {
624 RetryAction::server_error()
625 } else {
626 RetryAction::NoActionIndicated
631 }
632 }
633}
634
635#[cfg(test)]
636pub(crate) mod test {
637 use crate::imds::client::{Client, EndpointMode, ImdsResponseRetryClassifier};
638 use crate::provider_config::ProviderConfig;
639 use aws_smithy_async::rt::sleep::TokioSleep;
640 use aws_smithy_async::test_util::{instant_time_and_sleep, InstantSleep};
641 use aws_smithy_runtime::client::http::test_util::{
642 capture_request, ReplayEvent, StaticReplayClient,
643 };
644 use aws_smithy_runtime::test_util::capture_test_logs::capture_test_logs;
645 use aws_smithy_runtime_api::client::interceptors::context::{
646 Input, InterceptorContext, Output,
647 };
648 use aws_smithy_runtime_api::client::orchestrator::{
649 HttpRequest, HttpResponse, OrchestratorError,
650 };
651 use aws_smithy_runtime_api::client::result::ConnectorError;
652 use aws_smithy_runtime_api::client::retries::classifiers::{
653 ClassifyRetry, RetryAction, SharedRetryClassifier,
654 };
655 use aws_smithy_types::body::SdkBody;
656 use aws_smithy_types::error::display::DisplayErrorContext;
657 use aws_types::os_shim_internal::{Env, Fs};
658 use http::header::USER_AGENT;
659 use http::Uri;
660 use serde::Deserialize;
661 use std::collections::HashMap;
662 use std::error::Error;
663 use std::io;
664 use std::time::SystemTime;
665 use std::time::{Duration, UNIX_EPOCH};
666 use tracing_test::traced_test;
667
668 macro_rules! assert_full_error_contains {
669 ($err:expr, $contains:expr) => {
670 let err = $err;
671 let message = format!(
672 "{}",
673 aws_smithy_types::error::display::DisplayErrorContext(&err)
674 );
675 assert!(
676 message.contains($contains),
677 "Error message '{message}' didn't contain text '{}'",
678 $contains
679 );
680 };
681 }
682
683 const TOKEN_A: &str = "AQAEAFTNrA4eEGx0AQgJ1arIq_Cc-t4tWt3fB0Hd8RKhXlKc5ccvhg==";
684 const TOKEN_B: &str = "alternatetoken==";
685
686 pub(crate) fn token_request(base: &str, ttl: u32) -> HttpRequest {
687 http::Request::builder()
688 .uri(format!("{}/latest/api/token", base))
689 .header("x-aws-ec2-metadata-token-ttl-seconds", ttl)
690 .method("PUT")
691 .body(SdkBody::empty())
692 .unwrap()
693 .try_into()
694 .unwrap()
695 }
696
697 pub(crate) fn token_response(ttl: u32, token: &'static str) -> HttpResponse {
698 HttpResponse::try_from(
699 http::Response::builder()
700 .status(200)
701 .header("X-aws-ec2-metadata-token-ttl-seconds", ttl)
702 .body(SdkBody::from(token))
703 .unwrap(),
704 )
705 .unwrap()
706 }
707
708 pub(crate) fn imds_request(path: &'static str, token: &str) -> HttpRequest {
709 http::Request::builder()
710 .uri(Uri::from_static(path))
711 .method("GET")
712 .header("x-aws-ec2-metadata-token", token)
713 .body(SdkBody::empty())
714 .unwrap()
715 .try_into()
716 .unwrap()
717 }
718
719 pub(crate) fn imds_response(body: &'static str) -> HttpResponse {
720 HttpResponse::try_from(
721 http::Response::builder()
722 .status(200)
723 .body(SdkBody::from(body))
724 .unwrap(),
725 )
726 .unwrap()
727 }
728
729 pub(crate) fn make_imds_client(http_client: &StaticReplayClient) -> super::Client {
730 tokio::time::pause();
731 super::Client::builder()
732 .configure(
733 &ProviderConfig::no_configuration()
734 .with_sleep_impl(InstantSleep::unlogged())
735 .with_http_client(http_client.clone()),
736 )
737 .build()
738 }
739
740 fn mock_imds_client(events: Vec<ReplayEvent>) -> (Client, StaticReplayClient) {
741 let http_client = StaticReplayClient::new(events);
742 let client = make_imds_client(&http_client);
743 (client, http_client)
744 }
745
746 #[tokio::test]
747 async fn client_caches_token() {
748 let (client, http_client) = mock_imds_client(vec![
749 ReplayEvent::new(
750 token_request("http://169.254.169.254", 21600),
751 token_response(21600, TOKEN_A),
752 ),
753 ReplayEvent::new(
754 imds_request("http://169.254.169.254/latest/metadata", TOKEN_A),
755 imds_response(r#"test-imds-output"#),
756 ),
757 ReplayEvent::new(
758 imds_request("http://169.254.169.254/latest/metadata2", TOKEN_A),
759 imds_response("output2"),
760 ),
761 ]);
762 let metadata = client.get("/latest/metadata").await.expect("failed");
764 assert_eq!("test-imds-output", metadata.as_ref());
765 let metadata = client.get("/latest/metadata2").await.expect("failed");
767 assert_eq!("output2", metadata.as_ref());
768 http_client.assert_requests_match(&[]);
769 }
770
771 #[tokio::test]
772 async fn token_can_expire() {
773 let (_, http_client) = mock_imds_client(vec![
774 ReplayEvent::new(
775 token_request("http://[fd00:ec2::254]", 600),
776 token_response(600, TOKEN_A),
777 ),
778 ReplayEvent::new(
779 imds_request("http://[fd00:ec2::254]/latest/metadata", TOKEN_A),
780 imds_response(r#"test-imds-output1"#),
781 ),
782 ReplayEvent::new(
783 token_request("http://[fd00:ec2::254]", 600),
784 token_response(600, TOKEN_B),
785 ),
786 ReplayEvent::new(
787 imds_request("http://[fd00:ec2::254]/latest/metadata", TOKEN_B),
788 imds_response(r#"test-imds-output2"#),
789 ),
790 ]);
791 let (time_source, sleep) = instant_time_and_sleep(UNIX_EPOCH);
792 let client = super::Client::builder()
793 .configure(
794 &ProviderConfig::no_configuration()
795 .with_http_client(http_client.clone())
796 .with_time_source(time_source.clone())
797 .with_sleep_impl(sleep),
798 )
799 .endpoint_mode(EndpointMode::IpV6)
800 .token_ttl(Duration::from_secs(600))
801 .build();
802
803 let resp1 = client.get("/latest/metadata").await.expect("success");
804 time_source.advance(Duration::from_secs(600));
806 let resp2 = client.get("/latest/metadata").await.expect("success");
807 http_client.assert_requests_match(&[]);
808 assert_eq!("test-imds-output1", resp1.as_ref());
809 assert_eq!("test-imds-output2", resp2.as_ref());
810 }
811
812 #[tokio::test]
814 async fn token_refresh_buffer() {
815 let _logs = capture_test_logs();
816 let (_, http_client) = mock_imds_client(vec![
817 ReplayEvent::new(
818 token_request("http://[fd00:ec2::254]", 600),
819 token_response(600, TOKEN_A),
820 ),
821 ReplayEvent::new(
823 imds_request("http://[fd00:ec2::254]/latest/metadata", TOKEN_A),
824 imds_response(r#"test-imds-output1"#),
825 ),
826 ReplayEvent::new(
828 imds_request("http://[fd00:ec2::254]/latest/metadata", TOKEN_A),
829 imds_response(r#"test-imds-output2"#),
830 ),
831 ReplayEvent::new(
833 token_request("http://[fd00:ec2::254]", 600),
834 token_response(600, TOKEN_B),
835 ),
836 ReplayEvent::new(
837 imds_request("http://[fd00:ec2::254]/latest/metadata", TOKEN_B),
838 imds_response(r#"test-imds-output3"#),
839 ),
840 ]);
841 let (time_source, sleep) = instant_time_and_sleep(UNIX_EPOCH);
842 let client = super::Client::builder()
843 .configure(
844 &ProviderConfig::no_configuration()
845 .with_sleep_impl(sleep)
846 .with_http_client(http_client.clone())
847 .with_time_source(time_source.clone()),
848 )
849 .endpoint_mode(EndpointMode::IpV6)
850 .token_ttl(Duration::from_secs(600))
851 .build();
852
853 tracing::info!("resp1 -----------------------------------------------------------");
854 let resp1 = client.get("/latest/metadata").await.expect("success");
855 time_source.advance(Duration::from_secs(400));
857 tracing::info!("resp2 -----------------------------------------------------------");
858 let resp2 = client.get("/latest/metadata").await.expect("success");
859 time_source.advance(Duration::from_secs(150));
860 tracing::info!("resp3 -----------------------------------------------------------");
861 let resp3 = client.get("/latest/metadata").await.expect("success");
862 http_client.assert_requests_match(&[]);
863 assert_eq!("test-imds-output1", resp1.as_ref());
864 assert_eq!("test-imds-output2", resp2.as_ref());
865 assert_eq!("test-imds-output3", resp3.as_ref());
866 }
867
868 #[tokio::test]
870 #[traced_test]
871 async fn retry_500() {
872 let (client, http_client) = mock_imds_client(vec![
873 ReplayEvent::new(
874 token_request("http://169.254.169.254", 21600),
875 token_response(21600, TOKEN_A),
876 ),
877 ReplayEvent::new(
878 imds_request("http://169.254.169.254/latest/metadata", TOKEN_A),
879 http::Response::builder()
880 .status(500)
881 .body(SdkBody::empty())
882 .unwrap(),
883 ),
884 ReplayEvent::new(
885 imds_request("http://169.254.169.254/latest/metadata", TOKEN_A),
886 imds_response("ok"),
887 ),
888 ]);
889 assert_eq!(
890 "ok",
891 client
892 .get("/latest/metadata")
893 .await
894 .expect("success")
895 .as_ref()
896 );
897 http_client.assert_requests_match(&[]);
898
899 for request in http_client.actual_requests() {
901 assert!(request.headers().get(USER_AGENT).is_some());
902 }
903 }
904
905 #[tokio::test]
907 #[traced_test]
908 async fn retry_token_failure() {
909 let (client, http_client) = mock_imds_client(vec![
910 ReplayEvent::new(
911 token_request("http://169.254.169.254", 21600),
912 http::Response::builder()
913 .status(500)
914 .body(SdkBody::empty())
915 .unwrap(),
916 ),
917 ReplayEvent::new(
918 token_request("http://169.254.169.254", 21600),
919 token_response(21600, TOKEN_A),
920 ),
921 ReplayEvent::new(
922 imds_request("http://169.254.169.254/latest/metadata", TOKEN_A),
923 imds_response("ok"),
924 ),
925 ]);
926 assert_eq!(
927 "ok",
928 client
929 .get("/latest/metadata")
930 .await
931 .expect("success")
932 .as_ref()
933 );
934 http_client.assert_requests_match(&[]);
935 }
936
937 #[tokio::test]
939 #[traced_test]
940 async fn retry_metadata_401() {
941 let (client, http_client) = mock_imds_client(vec![
942 ReplayEvent::new(
943 token_request("http://169.254.169.254", 21600),
944 token_response(0, TOKEN_A),
945 ),
946 ReplayEvent::new(
947 imds_request("http://169.254.169.254/latest/metadata", TOKEN_A),
948 http::Response::builder()
949 .status(401)
950 .body(SdkBody::empty())
951 .unwrap(),
952 ),
953 ReplayEvent::new(
954 token_request("http://169.254.169.254", 21600),
955 token_response(21600, TOKEN_B),
956 ),
957 ReplayEvent::new(
958 imds_request("http://169.254.169.254/latest/metadata", TOKEN_B),
959 imds_response("ok"),
960 ),
961 ]);
962 assert_eq!(
963 "ok",
964 client
965 .get("/latest/metadata")
966 .await
967 .expect("success")
968 .as_ref()
969 );
970 http_client.assert_requests_match(&[]);
971 }
972
973 #[tokio::test]
975 #[traced_test]
976 async fn no_403_retry() {
977 let (client, http_client) = mock_imds_client(vec![ReplayEvent::new(
978 token_request("http://169.254.169.254", 21600),
979 http::Response::builder()
980 .status(403)
981 .body(SdkBody::empty())
982 .unwrap(),
983 )]);
984 let err = client.get("/latest/metadata").await.expect_err("no token");
985 assert_full_error_contains!(err, "forbidden");
986 http_client.assert_requests_match(&[]);
987 }
988
989 #[test]
991 fn successful_response_properly_classified() {
992 let mut ctx = InterceptorContext::new(Input::doesnt_matter());
993 ctx.set_output_or_error(Ok(Output::doesnt_matter()));
994 ctx.set_response(imds_response("").map(|_| SdkBody::empty()));
995 let classifier = ImdsResponseRetryClassifier::default();
996 assert_eq!(
997 RetryAction::NoActionIndicated,
998 classifier.classify_retry(&ctx)
999 );
1000
1001 let mut ctx = InterceptorContext::new(Input::doesnt_matter());
1003 ctx.set_output_or_error(Err(OrchestratorError::connector(ConnectorError::io(
1004 io::Error::new(io::ErrorKind::BrokenPipe, "fail to parse").into(),
1005 ))));
1006 assert_eq!(
1007 RetryAction::NoActionIndicated,
1008 classifier.classify_retry(&ctx)
1009 );
1010 }
1011
1012 #[tokio::test]
1014 async fn user_provided_retry_classifier() {
1015 #[derive(Clone, Debug)]
1016 struct UserProvidedRetryClassifier;
1017
1018 impl ClassifyRetry for UserProvidedRetryClassifier {
1019 fn name(&self) -> &'static str {
1020 "UserProvidedRetryClassifier"
1021 }
1022
1023 fn classify_retry(&self, _ctx: &InterceptorContext) -> RetryAction {
1025 RetryAction::RetryForbidden
1026 }
1027 }
1028
1029 let events = vec![
1030 ReplayEvent::new(
1031 token_request("http://169.254.169.254", 21600),
1032 token_response(0, TOKEN_A),
1033 ),
1034 ReplayEvent::new(
1035 imds_request("http://169.254.169.254/latest/metadata", TOKEN_A),
1036 http::Response::builder()
1037 .status(401)
1038 .body(SdkBody::empty())
1039 .unwrap(),
1040 ),
1041 ReplayEvent::new(
1042 token_request("http://169.254.169.254", 21600),
1043 token_response(21600, TOKEN_B),
1044 ),
1045 ReplayEvent::new(
1046 imds_request("http://169.254.169.254/latest/metadata", TOKEN_B),
1047 imds_response("ok"),
1048 ),
1049 ];
1050 let http_client = StaticReplayClient::new(events);
1051
1052 let imds_client = super::Client::builder()
1053 .configure(
1054 &ProviderConfig::no_configuration()
1055 .with_sleep_impl(InstantSleep::unlogged())
1056 .with_http_client(http_client.clone()),
1057 )
1058 .retry_classifier(SharedRetryClassifier::new(UserProvidedRetryClassifier))
1059 .build();
1060
1061 let res = imds_client
1062 .get("/latest/metadata")
1063 .await
1064 .expect_err("Client should error");
1065
1066 assert_full_error_contains!(res, "401");
1069 }
1070
1071 #[tokio::test]
1073 async fn invalid_token() {
1074 let (client, http_client) = mock_imds_client(vec![ReplayEvent::new(
1075 token_request("http://169.254.169.254", 21600),
1076 token_response(21600, "invalid\nheader\nvalue\0"),
1077 )]);
1078 let err = client.get("/latest/metadata").await.expect_err("no token");
1079 assert_full_error_contains!(err, "invalid token");
1080 http_client.assert_requests_match(&[]);
1081 }
1082
1083 #[tokio::test]
1084 async fn non_utf8_response() {
1085 let (client, http_client) = mock_imds_client(vec![
1086 ReplayEvent::new(
1087 token_request("http://169.254.169.254", 21600),
1088 token_response(21600, TOKEN_A).map(SdkBody::from),
1089 ),
1090 ReplayEvent::new(
1091 imds_request("http://169.254.169.254/latest/metadata", TOKEN_A),
1092 http::Response::builder()
1093 .status(200)
1094 .body(SdkBody::from(vec![0xA0, 0xA1]))
1095 .unwrap(),
1096 ),
1097 ]);
1098 let err = client.get("/latest/metadata").await.expect_err("no token");
1099 assert_full_error_contains!(err, "invalid UTF-8");
1100 http_client.assert_requests_match(&[]);
1101 }
1102
1103 #[cfg_attr(windows, ignore)]
1105 #[tokio::test]
1107 #[cfg(feature = "rustls")]
1108 async fn one_second_connect_timeout() {
1109 use crate::imds::client::ImdsError;
1110 let client = Client::builder()
1111 .endpoint("http://240.0.0.0")
1113 .expect("valid uri")
1114 .build();
1115 let now = SystemTime::now();
1116 let resp = client
1117 .get("/latest/metadata")
1118 .await
1119 .expect_err("240.0.0.0 will never resolve");
1120 match resp {
1121 err @ ImdsError::FailedToLoadToken(_)
1122 if format!("{}", DisplayErrorContext(&err)).contains("timeout") => {} other => panic!(
1124 "wrong error, expected construction failure with TimedOutError inside: {}",
1125 DisplayErrorContext(&other)
1126 ),
1127 }
1128 let time_elapsed = now.elapsed().unwrap();
1129 assert!(
1130 time_elapsed > Duration::from_secs(1),
1131 "time_elapsed should be greater than 1s but was {:?}",
1132 time_elapsed
1133 );
1134 assert!(
1135 time_elapsed < Duration::from_secs(2),
1136 "time_elapsed should be less than 2s but was {:?}",
1137 time_elapsed
1138 );
1139 }
1140
1141 #[tokio::test]
1143 async fn retry_connect_timeouts() {
1144 let http_client = StaticReplayClient::new(vec![]);
1145 let imds_client = super::Client::builder()
1146 .retry_classifier(SharedRetryClassifier::new(
1147 ImdsResponseRetryClassifier::default().with_retry_connect_timeouts(true),
1148 ))
1149 .configure(&ProviderConfig::no_configuration().with_http_client(http_client.clone()))
1150 .operation_timeout(Duration::from_secs(1))
1151 .endpoint("http://240.0.0.0")
1152 .expect("valid uri")
1153 .build();
1154
1155 let now = SystemTime::now();
1156 let _res = imds_client
1157 .get("/latest/metadata")
1158 .await
1159 .expect_err("240.0.0.0 will never resolve");
1160 let time_elapsed: Duration = now.elapsed().unwrap();
1161
1162 assert!(
1163 time_elapsed > Duration::from_secs(1),
1164 "time_elapsed should be greater than 1s but was {:?}",
1165 time_elapsed
1166 );
1167
1168 assert!(
1169 time_elapsed < Duration::from_secs(2),
1170 "time_elapsed should be less than 2s but was {:?}",
1171 time_elapsed
1172 );
1173 }
1174
1175 #[derive(Debug, Deserialize)]
1176 struct ImdsConfigTest {
1177 env: HashMap<String, String>,
1178 fs: HashMap<String, String>,
1179 endpoint_override: Option<String>,
1180 mode_override: Option<String>,
1181 result: Result<String, String>,
1182 docs: String,
1183 }
1184
1185 #[tokio::test]
1186 async fn endpoint_config_tests() -> Result<(), Box<dyn Error>> {
1187 let _logs = capture_test_logs();
1188
1189 let test_cases = std::fs::read_to_string("test-data/imds-config/imds-endpoint-tests.json")?;
1190 #[derive(Deserialize)]
1191 struct TestCases {
1192 tests: Vec<ImdsConfigTest>,
1193 }
1194
1195 let test_cases: TestCases = serde_json::from_str(&test_cases)?;
1196 let test_cases = test_cases.tests;
1197 for test in test_cases {
1198 check(test).await;
1199 }
1200 Ok(())
1201 }
1202
1203 async fn check(test_case: ImdsConfigTest) {
1204 let (http_client, watcher) = capture_request(None);
1205 let provider_config = ProviderConfig::no_configuration()
1206 .with_sleep_impl(TokioSleep::new())
1207 .with_env(Env::from(test_case.env))
1208 .with_fs(Fs::from_map(test_case.fs))
1209 .with_http_client(http_client);
1210 let mut imds_client = Client::builder().configure(&provider_config);
1211 if let Some(endpoint_override) = test_case.endpoint_override {
1212 imds_client = imds_client
1213 .endpoint(endpoint_override)
1214 .expect("invalid URI");
1215 }
1216
1217 if let Some(mode_override) = test_case.mode_override {
1218 imds_client = imds_client.endpoint_mode(mode_override.parse().unwrap());
1219 }
1220
1221 let imds_client = imds_client.build();
1222 match &test_case.result {
1223 Ok(uri) => {
1224 let _ = imds_client.get("/hello").await;
1226 assert_eq!(&watcher.expect_request().uri().to_string(), uri);
1227 }
1228 Err(expected) => {
1229 let err = imds_client.get("/hello").await.expect_err("it should fail");
1230 let message = format!("{}", DisplayErrorContext(&err));
1231 assert!(
1232 message.contains(expected),
1233 "{}\nexpected error: {expected}\nactual error: {message}",
1234 test_case.docs
1235 );
1236 }
1237 };
1238 }
1239}