use crate::connector::expect_connector;
use crate::imds::client::error::{BuildError, ImdsError, InnerImdsError, InvalidEndpointMode};
use crate::imds::client::token::TokenMiddleware;
use crate::provider_config::ProviderConfig;
use crate::PKG_VERSION;
use aws_http::user_agent::{ApiMetadata, AwsUserAgent, UserAgentStage};
use aws_sdk_sso::config::timeout::TimeoutConfig;
use aws_smithy_client::http_connector::ConnectorSettings;
use aws_smithy_client::{erase::DynConnector, SdkSuccess};
use aws_smithy_client::{retry, SdkError};
use aws_smithy_http::body::SdkBody;
use aws_smithy_http::endpoint::apply_endpoint;
use aws_smithy_http::operation;
use aws_smithy_http::operation::{Metadata, Operation};
use aws_smithy_http::response::ParseStrictResponse;
use aws_smithy_http::retry::ClassifyRetry;
use aws_smithy_http_tower::map_request::{
AsyncMapRequestLayer, AsyncMapRequestService, MapRequestLayer, MapRequestService,
};
use aws_smithy_types::error::display::DisplayErrorContext;
use aws_smithy_types::retry::{ErrorKind, RetryKind};
use aws_types::os_shim_internal::Env;
use bytes::Bytes;
use http::{Response, Uri};
use std::borrow::Cow;
use std::error::Error;
use std::str::FromStr;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::OnceCell;
pub mod error;
mod token;
/// Token TTL used when the builder is not given one: 21,600 seconds (6 hours).
const DEFAULT_TOKEN_TTL: Duration = Duration::from_secs(21_600);
/// Maximum attempts passed to the retry configuration when the builder does not override it.
const DEFAULT_ATTEMPTS: u32 = 4;
/// Connect timeout used when the builder does not specify one.
const DEFAULT_CONNECT_TIMEOUT: Duration = Duration::from_secs(1);
/// Read timeout used when the builder does not specify one.
const DEFAULT_READ_TIMEOUT: Duration = Duration::from_secs(1);
/// Builds the user-agent value that is attached to each outgoing IMDS operation.
///
/// The API metadata identifies the service as `"imds"` at this crate's version; the
/// rest of the user agent is derived from the real process environment.
fn user_agent() -> AwsUserAgent {
    let api_metadata = ApiMetadata::new("imds", PKG_VERSION);
    AwsUserAgent::new_from_environment(Env::real(), api_metadata)
}
/// Client for reading data from IMDS.
///
/// All state is held behind an `Arc`, so the derived `Clone` copies only a
/// reference-counted handle to the same inner client.
#[derive(Clone, Debug)]
pub struct Client {
// Shared endpoint + underlying HTTP client.
inner: Arc<ClientInner>,
}
/// State shared by every clone of [`Client`].
#[derive(Debug)]
struct ClientInner {
/// Base URI that `make_operation` combines with the requested path.
endpoint: Uri,
/// Underlying client that dispatches operations through [`ImdsMiddleware`].
smithy_client: aws_smithy_client::Client<DynConnector, ImdsMiddleware>,
}
/// A [`Client`] that is built from its [`Builder`] on first use.
///
/// The build result (success or failure) is stored in a `OnceCell`, so the builder
/// runs at most once and later callers observe the same outcome.
#[derive(Debug)]
pub(super) struct LazyClient {
/// Cached outcome of building the client.
client: OnceCell<Result<Client, BuildError>>,
/// Builder that is cloned and run when the cell is first initialized.
builder: Builder,
}
impl LazyClient {
    /// Wraps an already-built [`Client`] so that no lazy initialization is needed.
    ///
    /// The stored builder is a default one; it is never run because the cell is
    /// already populated.
    pub(super) fn from_ready_client(client: Client) -> Self {
        let client = OnceCell::from(Ok(client));
        let builder = Builder::default();
        Self { client, builder }
    }

    /// Returns the client, building it on the first call.
    ///
    /// A build failure is logged at `warn` level once (when the cell is initialized)
    /// and then cached, so every later call returns a reference to the same error.
    pub(super) async fn client(&self) -> Result<&Client, &BuildError> {
        let template = &self.builder;
        let initialize = || async {
            let outcome = template.clone().build().await;
            if let Err(err) = &outcome {
                tracing::warn!(err = %DisplayErrorContext(err), "failed to create IMDS client")
            }
            outcome
        };
        self.client.get_or_init(initialize).await.as_ref()
    }
}
impl Client {
/// Returns a [`Builder`] with every option unset.
pub fn builder() -> Builder {
Builder::default()
}
/// Requests `path` from IMDS and returns the response body as a `String`.
///
/// `path` is parsed as a URI and combined with the client's endpoint before the
/// operation is dispatched through the underlying client.
///
/// # Errors
///
/// Every failure is reported as an [`ImdsError`]:
/// - an invalid `path`, or a failure to apply the endpoint, is "unexpected";
/// - a non-success HTTP status becomes an error-response variant carrying the raw response;
/// - a body that is not valid UTF-8 is "unexpected";
/// - timeout, dispatch and response errors are mapped to an I/O error.
pub async fn get(&self, path: &str) -> Result<String, ImdsError> {
let operation = self.make_operation(path)?;
self.inner
.smithy_client
.call(operation)
.await
.map_err(|err| match err {
// Construction failure that carries a source error. If that source is itself an
// `ImdsError` it is returned unchanged; anything else is wrapped as unexpected.
// NOTE(review): the `token_failure` binding suggests the source originates from the
// token middleware — confirm against `token.rs`.
SdkError::ConstructionFailure(_) if err.source().is_some() => {
match err.into_source().map(|e| e.downcast::<ImdsError>()) {
Ok(Ok(token_failure)) => *token_failure,
Ok(Err(err)) => ImdsError::unexpected(err),
Err(err) => ImdsError::unexpected(err),
}
}
// Construction failure without a source.
SdkError::ConstructionFailure(_) => ImdsError::unexpected(err),
// The response was received and parsed by `ImdsGetResponseHandler`.
SdkError::ServiceError(context) => match context.err() {
InnerImdsError::InvalidUtf8 => {
ImdsError::unexpected("IMDS returned invalid UTF-8")
}
InnerImdsError::BadStatus => {
// Hand the raw HTTP response (first element of `into_parts`) to the error.
ImdsError::error_response(context.into_raw().into_parts().0)
}
},
SdkError::TimeoutError(_)
| SdkError::DispatchFailure(_)
| SdkError::ResponseError(_) => ImdsError::io_error(err),
// Catch-all for any other `SdkError` variant.
_ => ImdsError::unexpected(err),
})
}
/// Builds the operation that fetches `path`.
///
/// The path is parsed as a `Uri`, merged with the configured endpoint via
/// `apply_endpoint`, and wrapped in a request with an empty body. The user agent is
/// stored in the request's property bag, and the operation is tagged with the
/// `("get", "imds")` metadata and the IMDS retry classifier.
///
/// Returns an "unexpected" [`ImdsError`] if `path` is not a valid URI or the
/// endpoint cannot be applied.
fn make_operation(
&self,
path: &str,
) -> Result<Operation<ImdsGetResponseHandler, ImdsResponseRetryClassifier>, ImdsError> {
let mut base_uri: Uri = path.parse().map_err(|_| {
ImdsError::unexpected("IMDS path was not a valid URI. Hint: does it begin with `/`?")
})?;
apply_endpoint(&mut base_uri, &self.inner.endpoint, None).map_err(ImdsError::unexpected)?;
let request = http::Request::builder()
.uri(base_uri)
.body(SdkBody::empty())
.expect("valid request");
let mut request = operation::Request::new(request);
request.properties_mut().insert(user_agent());
Ok(Operation::new(request, ImdsGetResponseHandler)
.with_metadata(Metadata::new("get", "imds"))
.with_retry_classifier(ImdsResponseRetryClassifier))
}
}
/// Middleware stack applied to every IMDS request by the underlying client.
#[derive(Clone, Debug)]
struct ImdsMiddleware {
/// Request mapper that is cloned into each layered service (see the `tower::Layer` impl).
token_loader: TokenMiddleware,
}
impl<S> tower::Layer<S> for ImdsMiddleware {
    type Service = AsyncMapRequestService<MapRequestService<S, UserAgentStage>, TokenMiddleware>;

    /// Wraps `inner` first with the user-agent request mapper and then, outermost,
    /// with the asynchronous token mapper.
    fn layer(&self, inner: S) -> Self::Service {
        let token_layer = AsyncMapRequestLayer::for_mapper(self.token_loader.clone());
        let user_agent_layer = MapRequestLayer::for_mapper(UserAgentStage::new());
        let with_user_agent = user_agent_layer.layer(inner);
        token_layer.layer(with_user_agent)
    }
}
/// Stateless response parser used for IMDS `get` operations.
#[derive(Copy, Clone)]
struct ImdsGetResponseHandler;
impl ParseStrictResponse for ImdsGetResponseHandler {
    type Output = Result<String, InnerImdsError>;

    /// Converts a fully-buffered response into its body text.
    ///
    /// Any non-success status yields `BadStatus`; a success status whose body is not
    /// valid UTF-8 yields `InvalidUtf8`.
    fn parse(&self, response: &Response<Bytes>) -> Self::Output {
        if !response.status().is_success() {
            return Err(InnerImdsError::BadStatus);
        }
        match std::str::from_utf8(response.body().as_ref()) {
            Ok(text) => Ok(text.to_string()),
            Err(_) => Err(InnerImdsError::InvalidUtf8),
        }
    }
}
/// Selects which built-in IMDS endpoint address to use.
///
/// Parsed case-insensitively from `"ipv4"` / `"ipv6"` via `FromStr`.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum EndpointMode {
/// Use the IPv4 endpoint, `http://169.254.169.254`.
IpV4,
/// Use the IPv6 endpoint, `http://[fd00:ec2::254]`.
IpV6,
}
impl FromStr for EndpointMode {
    type Err = InvalidEndpointMode;

    /// Parses `"ipv4"` or `"ipv6"`, ignoring ASCII case.
    ///
    /// Any other input is returned inside an [`InvalidEndpointMode`] error.
    fn from_str(value: &str) -> Result<Self, Self::Err> {
        if value.eq_ignore_ascii_case("ipv4") {
            Ok(EndpointMode::IpV4)
        } else if value.eq_ignore_ascii_case("ipv6") {
            Ok(EndpointMode::IpV6)
        } else {
            Err(InvalidEndpointMode::new(value.to_owned()))
        }
    }
}
impl EndpointMode {
    /// Returns the built-in IMDS endpoint URI for this mode.
    fn endpoint(&self) -> Uri {
        let address = match self {
            EndpointMode::IpV4 => "http://169.254.169.254",
            EndpointMode::IpV6 => "http://[fd00:ec2::254]",
        };
        Uri::from_static(address)
    }
}
/// Builder for [`Client`]. Every option is optional; unset options fall back to the
/// module defaults when `build` runs.
#[derive(Default, Debug, Clone)]
pub struct Builder {
/// Maximum number of attempts per request (falls back to `DEFAULT_ATTEMPTS`).
max_attempts: Option<u32>,
/// Explicit endpoint, if one was set; otherwise the endpoint is resolved from the provider config.
endpoint: Option<EndpointSource>,
/// Endpoint mode set in code via `endpoint_mode`.
mode_override: Option<EndpointMode>,
/// TTL passed to the token middleware (falls back to `DEFAULT_TOKEN_TTL`).
token_ttl: Option<Duration>,
/// Connect timeout (falls back to `DEFAULT_CONNECT_TIMEOUT`).
connect_timeout: Option<Duration>,
/// Read timeout (falls back to `DEFAULT_READ_TIMEOUT`).
read_timeout: Option<Duration>,
/// Provider configuration supplying the connector, environment, profile, time source and sleep.
config: Option<ProviderConfig>,
}
impl Builder {
    /// Overrides the maximum number of attempts made per request.
    ///
    /// When unset, `DEFAULT_ATTEMPTS` is used.
    pub fn max_attempts(mut self, max_attempts: u32) -> Self {
        self.max_attempts = Some(max_attempts);
        self
    }

    /// Supplies the [`ProviderConfig`] used to obtain the HTTP connector, environment,
    /// profile, time source and sleep implementation.
    ///
    /// The configuration is cloned into the builder. When unset, the default
    /// `ProviderConfig` is used.
    pub fn configure(mut self, provider_config: &ProviderConfig) -> Self {
        self.config = Some(provider_config.clone());
        self
    }

    /// Sets an explicit endpoint URI.
    ///
    /// An explicit endpoint is used as-is: the environment and profile are not
    /// consulted, and any mode set with [`Builder::endpoint_mode`] is ignored with a
    /// warning.
    pub fn endpoint(mut self, endpoint: impl Into<Uri>) -> Self {
        self.endpoint = Some(EndpointSource::Explicit(endpoint.into()));
        self
    }

    /// Sets the endpoint mode in code.
    ///
    /// This takes precedence over a mode found in the environment or profile, but an
    /// endpoint URI found in the environment or profile still wins over any mode.
    pub fn endpoint_mode(mut self, mode: EndpointMode) -> Self {
        self.mode_override = Some(mode);
        self
    }

    /// Overrides the TTL handed to the token middleware.
    ///
    /// When unset, `DEFAULT_TOKEN_TTL` is used.
    pub fn token_ttl(mut self, ttl: Duration) -> Self {
        self.token_ttl = Some(ttl);
        self
    }

    /// Overrides the connect timeout.
    ///
    /// When unset, `DEFAULT_CONNECT_TIMEOUT` is used.
    pub fn connect_timeout(mut self, timeout: Duration) -> Self {
        self.connect_timeout = Some(timeout);
        self
    }

    /// Overrides the read timeout.
    ///
    /// When unset, `DEFAULT_READ_TIMEOUT` is used.
    pub fn read_timeout(mut self, timeout: Duration) -> Self {
        self.read_timeout = Some(timeout);
        self
    }

    /// Defers construction: returns a [`LazyClient`] that runs this builder on first use.
    pub(super) fn build_lazy(self) -> LazyClient {
        LazyClient {
            client: OnceCell::new(),
            builder: self,
        }
    }

    /// Builds the [`Client`].
    ///
    /// # Errors
    ///
    /// Returns a [`BuildError`] when the endpoint cannot be resolved, i.e. when an
    /// endpoint URI or endpoint mode read from the environment or profile is invalid.
    pub async fn build(self) -> Result<Client, BuildError> {
        let config = self.config.unwrap_or_default();
        let timeout_config = TimeoutConfig::builder()
            .connect_timeout(self.connect_timeout.unwrap_or(DEFAULT_CONNECT_TIMEOUT))
            .read_timeout(self.read_timeout.unwrap_or(DEFAULT_READ_TIMEOUT))
            .build();
        let connector_settings = ConnectorSettings::from_timeout_config(&timeout_config);
        let connector = expect_connector(config.connector(&connector_settings));

        // Without an explicit endpoint, resolve it from the provider config's
        // environment and profile.
        let endpoint_source = self
            .endpoint
            .unwrap_or_else(|| EndpointSource::Env(config.clone()));
        let endpoint = endpoint_source.endpoint(self.mode_override).await?;

        let retry_config = retry::Config::default()
            .with_max_attempts(self.max_attempts.unwrap_or(DEFAULT_ATTEMPTS));

        // The token middleware gets its own copies of the connector, endpoint, retry
        // and timeout settings; the originals are still needed for the client below.
        let token_loader = token::TokenMiddleware::new(
            connector.clone(),
            config.time_source(),
            endpoint.clone(),
            self.token_ttl.unwrap_or(DEFAULT_TOKEN_TTL),
            retry_config.clone(),
            timeout_config.clone(),
            config.sleep(),
        );
        let middleware = ImdsMiddleware { token_loader };

        // Last use of `connector`, so it is moved rather than cloned.
        let mut smithy_builder = aws_smithy_client::Client::builder()
            .connector(connector)
            .middleware(middleware)
            .retry_config(retry_config)
            .operation_timeout_config(timeout_config.into());
        smithy_builder.set_sleep_impl(config.sleep());
        let smithy_client = smithy_builder.build();

        let client = Client {
            inner: Arc::new(ClientInner {
                endpoint,
                smithy_client,
            }),
        };
        Ok(client)
    }
}
/// Names of the environment variables consulted when resolving the endpoint.
mod env {
/// Environment variable holding an endpoint URI override.
pub(super) const ENDPOINT: &str = "AWS_EC2_METADATA_SERVICE_ENDPOINT";
/// Environment variable holding the endpoint mode (parsed as an `EndpointMode`).
pub(super) const ENDPOINT_MODE: &str = "AWS_EC2_METADATA_SERVICE_ENDPOINT_MODE";
}
/// Profile keys consulted when the corresponding environment variable is not set.
mod profile_keys {
/// Profile key holding an endpoint URI override.
pub(super) const ENDPOINT: &str = "ec2_metadata_service_endpoint";
/// Profile key holding the endpoint mode (parsed as an `EndpointMode`).
pub(super) const ENDPOINT_MODE: &str = "ec2_metadata_service_endpoint_mode";
}
/// Where the IMDS endpoint URI comes from.
#[derive(Debug, Clone)]
enum EndpointSource {
/// A URI supplied directly through `Builder::endpoint`; used verbatim.
Explicit(Uri),
/// Resolve the endpoint from the environment and profile of this provider config.
Env(ProviderConfig),
}
impl EndpointSource {
    /// Resolves the endpoint URI.
    ///
    /// For an explicit endpoint the URI is returned as-is (a warning is logged if a
    /// mode override was also supplied). Otherwise the lookup order is:
    ///
    /// 1. endpoint URI from the environment, then from the profile;
    /// 2. endpoint mode from `mode_override`, then the environment, then the profile;
    /// 3. the IPv4 endpoint.
    ///
    /// Returns a [`BuildError`] if a configured URI or mode fails to parse.
    async fn endpoint(&self, mode_override: Option<EndpointMode>) -> Result<Uri, BuildError> {
        let conf = match self {
            EndpointSource::Explicit(uri) => {
                if mode_override.is_some() {
                    tracing::warn!(endpoint = ?uri, mode = ?mode_override,
                        "Endpoint mode override was set in combination with an explicit endpoint. \
                        The mode override will be ignored.")
                }
                return Ok(uri.clone());
            }
            EndpointSource::Env(conf) => conf,
        };

        let environment = conf.env();
        let profile = conf.profile().await;

        // A full endpoint URI beats any endpoint mode; the environment beats the profile.
        let uri_override = match environment.get(env::ENDPOINT) {
            Ok(uri) => Some(Cow::Owned(uri)),
            Err(_) => profile
                .and_then(|profile| profile.get(profile_keys::ENDPOINT))
                .map(Cow::Borrowed),
        };
        if let Some(uri) = uri_override {
            return Uri::try_from(uri.as_ref()).map_err(BuildError::invalid_endpoint_uri);
        }

        let mode = match mode_override {
            Some(mode) => mode,
            None => {
                // The environment and profile are only consulted when no mode was
                // set in code.
                let configured = match environment.get(env::ENDPOINT_MODE) {
                    Ok(raw) => Some(Cow::Owned(raw)),
                    Err(_) => profile
                        .and_then(|p| p.get(profile_keys::ENDPOINT_MODE))
                        .map(Cow::Borrowed),
                };
                match configured {
                    Some(raw) => raw
                        .parse::<EndpointMode>()
                        .map_err(BuildError::invalid_endpoint_mode)?,
                    None => EndpointMode::IpV4,
                }
            }
        };
        Ok(mode.endpoint())
    }
}
/// Stateless retry classifier attached to IMDS operations.
#[derive(Clone)]
struct ImdsResponseRetryClassifier;
impl ImdsResponseRetryClassifier {
fn classify(response: &operation::Response) -> RetryKind {
let status = response.http().status();
match status {
_ if status.is_server_error() => RetryKind::Error(ErrorKind::ServerError),
_ if status.as_u16() == 401 => RetryKind::Error(ErrorKind::ServerError),
_ => RetryKind::UnretryableFailure,
}
}
}
impl<T, E> ClassifyRetry<SdkSuccess<T>, SdkError<E>> for ImdsResponseRetryClassifier {
    /// Decides whether an operation result should be retried.
    ///
    /// Successes need no retry. Response and service errors are classified from the
    /// raw HTTP response; every other error kind is treated as unretryable.
    fn classify_retry(&self, response: Result<&SdkSuccess<T>, &SdkError<E>>) -> RetryKind {
        let raw = match response {
            Ok(_) => return RetryKind::Unnecessary,
            Err(SdkError::ResponseError(ctx)) => ctx.raw(),
            Err(SdkError::ServiceError(ctx)) => ctx.raw(),
            _ => return RetryKind::UnretryableFailure,
        };
        Self::classify(raw)
    }
}
#[cfg(test)]
pub(crate) mod test {
use crate::imds::client::{Client, EndpointMode, ImdsResponseRetryClassifier};
use crate::provider_config::ProviderConfig;
use aws_credential_types::time_source::{TestingTimeSource, TimeSource};
use aws_smithy_async::rt::sleep::TokioSleep;
use aws_smithy_client::erase::DynConnector;
use aws_smithy_client::test_connection::{capture_request, TestConnection};
use aws_smithy_client::{SdkError, SdkSuccess};
use aws_smithy_http::body::SdkBody;
use aws_smithy_http::operation;
use aws_smithy_types::retry::RetryKind;
use aws_types::os_shim_internal::{Env, Fs};
use http::header::USER_AGENT;
use http::Uri;
use serde::Deserialize;
use std::collections::HashMap;
use std::error::Error;
use std::io;
use std::time::{Duration, UNIX_EPOCH};
use tracing_test::traced_test;
// Asserts that the full rendering of an error (as produced by `DisplayErrorContext`)
// contains the given substring. On failure the panic message shows both the rendered
// error and the expected text.
macro_rules! assert_full_error_contains {
($err:expr, $contains:expr) => {
// Evaluate the error expression exactly once.
let err = $err;
let message = format!(
"{}",
aws_smithy_types::error::display::DisplayErrorContext(&err)
);
assert!(
message.contains($contains),
"Error message '{message}' didn't contain text '{}'",
$contains
);
};
}
/// Token body returned by the first mocked token response in most tests.
const TOKEN_A: &str = "AQAEAFTNrA4eEGx0AQgJ1arIq_Cc-t4tWt3fB0Hd8RKhXlKc5ccvhg==";
/// A second, distinct token used to verify that a new token was fetched.
const TOKEN_B: &str = "alternatetoken==";
/// Builds the token request the client is expected to send: a `PUT` to
/// `{base}/latest/api/token` carrying the requested TTL header.
pub(crate) fn token_request(base: &str, ttl: u32) -> http::Request<SdkBody> {
    let token_uri = format!("{}/latest/api/token", base);
    let builder = http::Request::builder()
        .uri(token_uri)
        .header("x-aws-ec2-metadata-token-ttl-seconds", ttl)
        .method("PUT");
    builder.body(SdkBody::empty()).unwrap()
}
/// Builds a mocked 200 token response whose body is `token` and whose TTL header is `ttl`.
pub(crate) fn token_response(ttl: u32, token: &'static str) -> http::Response<&'static str> {
    let builder = http::Response::builder()
        .status(200)
        .header("X-aws-ec2-metadata-token-ttl-seconds", ttl);
    builder.body(token).unwrap()
}
/// Builds the metadata request the client is expected to send: a `GET` to `path`
/// carrying `token` in the metadata-token header.
pub(crate) fn imds_request(path: &'static str, token: &str) -> http::Request<SdkBody> {
    let target = Uri::from_static(path);
    let builder = http::Request::builder()
        .uri(target)
        .method("GET")
        .header("x-aws-ec2-metadata-token", token);
    builder.body(SdkBody::empty()).unwrap()
}
/// Builds a mocked 200 metadata response with the given body.
pub(crate) fn imds_response(body: &'static str) -> http::Response<&'static str> {
    let ok = http::Response::builder().status(200);
    ok.body(body).unwrap()
}
/// Builds a client wired to the given test connection.
///
/// Tokio's clock is paused first; the provider config has no ambient configuration,
/// uses `TokioSleep`, and routes HTTP traffic through a clone of `conn`.
pub(crate) async fn make_client<T>(conn: &TestConnection<T>) -> super::Client
where
    SdkBody: From<T>,
    T: Send + 'static,
{
    tokio::time::pause();
    let provider_config = ProviderConfig::no_configuration()
        .with_sleep(TokioSleep::new())
        .with_http_connector(DynConnector::new(conn.clone()));
    let built = super::Client::builder()
        .configure(&provider_config)
        .build()
        .await;
    built.expect("valid client")
}
/// Two metadata lookups must share a single token request.
#[tokio::test]
async fn client_caches_token() {
    let token_exchange = (
        token_request("http://169.254.169.254", 21600),
        token_response(21600, TOKEN_A),
    );
    let first_lookup = (
        imds_request("http://169.254.169.254/latest/metadata", TOKEN_A),
        imds_response(r#"test-imds-output"#),
    );
    let second_lookup = (
        imds_request("http://169.254.169.254/latest/metadata2", TOKEN_A),
        imds_response("output2"),
    );
    let connection = TestConnection::new(vec![token_exchange, first_lookup, second_lookup]);
    let client = make_client(&connection).await;

    // Each lookup is asserted right after it completes.
    let first = client.get("/latest/metadata").await.expect("failed");
    assert_eq!(first, "test-imds-output");
    let second = client.get("/latest/metadata2").await.expect("failed");
    assert_eq!(second, "output2");

    connection.assert_requests_match(&[]);
}
/// With a 600s token TTL, advancing the test clock by 600s between two lookups must
/// cause a second token request before the second metadata request.
#[tokio::test]
async fn token_can_expire() {
let connection = TestConnection::new(vec![
(
token_request("http://[fd00:ec2::254]", 600),
token_response(600, TOKEN_A),
),
(
imds_request("http://[fd00:ec2::254]/latest/metadata", TOKEN_A),
imds_response(r#"test-imds-output1"#),
),
// After the clock advance a new token (TOKEN_B) is expected to be requested.
(
token_request("http://[fd00:ec2::254]", 600),
token_response(600, TOKEN_B),
),
(
imds_request("http://[fd00:ec2::254]/latest/metadata", TOKEN_B),
imds_response(r#"test-imds-output2"#),
),
]);
let mut time_source = TestingTimeSource::new(UNIX_EPOCH);
tokio::time::pause();
let client = super::Client::builder()
.configure(
&ProviderConfig::no_configuration()
.with_http_connector(DynConnector::new(connection.clone()))
.with_time_source(TimeSource::testing(&time_source))
.with_sleep(TokioSleep::new()),
)
.endpoint_mode(EndpointMode::IpV6)
.token_ttl(Duration::from_secs(600))
.build()
.await
.expect("valid client");
let resp1 = client.get("/latest/metadata").await.expect("success");
// Move the test clock forward by the full TTL.
time_source.advance(Duration::from_secs(600));
let resp2 = client.get("/latest/metadata").await.expect("success");
connection.assert_requests_match(&[]);
assert_eq!(resp1, "test-imds-output1");
assert_eq!(resp2, "test-imds-output2");
}
/// With a 600s token TTL, the fixture expects the cached token to still be used at
/// 400s, and a fresh token to be requested at 550s — i.e. before the full TTL has elapsed.
#[tokio::test]
async fn token_refresh_buffer() {
let connection = TestConnection::new(vec![
(
token_request("http://[fd00:ec2::254]", 600),
token_response(600, TOKEN_A),
),
// t = 0s
(
imds_request("http://[fd00:ec2::254]/latest/metadata", TOKEN_A),
imds_response(r#"test-imds-output1"#),
),
// t = 400s: same token, no new token request
(
imds_request("http://[fd00:ec2::254]/latest/metadata", TOKEN_A),
imds_response(r#"test-imds-output2"#),
),
// t = 550s: a new token is requested before the metadata lookup
(
token_request("http://[fd00:ec2::254]", 600),
token_response(600, TOKEN_B),
),
(
imds_request("http://[fd00:ec2::254]/latest/metadata", TOKEN_B),
imds_response(r#"test-imds-output3"#),
),
]);
tokio::time::pause();
let mut time_source = TestingTimeSource::new(UNIX_EPOCH);
let client = super::Client::builder()
.configure(
&ProviderConfig::no_configuration()
.with_sleep(TokioSleep::new())
.with_http_connector(DynConnector::new(connection.clone()))
.with_time_source(TimeSource::testing(&time_source)),
)
.endpoint_mode(EndpointMode::IpV6)
.token_ttl(Duration::from_secs(600))
.build()
.await
.expect("valid client");
let resp1 = client.get("/latest/metadata").await.expect("success");
time_source.advance(Duration::from_secs(400));
let resp2 = client.get("/latest/metadata").await.expect("success");
// 400s + 150s = 550s since the token was issued.
time_source.advance(Duration::from_secs(150));
let resp3 = client.get("/latest/metadata").await.expect("success");
connection.assert_requests_match(&[]);
assert_eq!(resp1, "test-imds-output1");
assert_eq!(resp2, "test-imds-output2");
assert_eq!(resp3, "test-imds-output3");
}
/// A 500 from the metadata endpoint is retried, and every request carries a user agent.
#[tokio::test]
#[traced_test]
async fn retry_500() {
    let token_exchange = (
        token_request("http://169.254.169.254", 21600),
        token_response(21600, TOKEN_A),
    );
    let server_error = (
        imds_request("http://169.254.169.254/latest/metadata", TOKEN_A),
        http::Response::builder().status(500).body("").unwrap(),
    );
    let retried_lookup = (
        imds_request("http://169.254.169.254/latest/metadata", TOKEN_A),
        imds_response("ok"),
    );
    let connection = TestConnection::new(vec![token_exchange, server_error, retried_lookup]);
    let client = make_client(&connection).await;

    let body = client.get("/latest/metadata").await.expect("success");
    assert_eq!(body, "ok");
    connection.assert_requests_match(&[]);

    for recorded in connection.requests().iter() {
        assert!(recorded.actual.headers().get(USER_AGENT).is_some());
    }
}
/// A 500 from the token endpoint is retried before the metadata lookup proceeds.
#[tokio::test]
#[traced_test]
async fn retry_token_failure() {
    let failed_token = (
        token_request("http://169.254.169.254", 21600),
        http::Response::builder().status(500).body("").unwrap(),
    );
    let retried_token = (
        token_request("http://169.254.169.254", 21600),
        token_response(21600, TOKEN_A),
    );
    let lookup = (
        imds_request("http://169.254.169.254/latest/metadata", TOKEN_A),
        imds_response("ok"),
    );
    let connection = TestConnection::new(vec![failed_token, retried_token, lookup]);
    let client = make_client(&connection).await;

    let body = client.get("/latest/metadata").await.expect("success");
    assert_eq!(body, "ok");
    connection.assert_requests_match(&[]);
}
/// A 401 from the metadata endpoint is retried: the fixture expects a new token
/// request followed by a second metadata request using the new token.
#[tokio::test]
#[traced_test]
async fn retry_metadata_401() {
let connection = TestConnection::new(vec![
// The first token response advertises a TTL of 0 in its header.
(
token_request("http://169.254.169.254", 21600),
token_response(0, TOKEN_A),
),
(
imds_request("http://169.254.169.254/latest/metadata", TOKEN_A),
http::Response::builder().status(401).body("").unwrap(),
),
// Retry: fetch a fresh token, then repeat the lookup with it.
(
token_request("http://169.254.169.254", 21600),
token_response(21600, TOKEN_B),
),
(
imds_request("http://169.254.169.254/latest/metadata", TOKEN_B),
imds_response("ok"),
),
]);
let client = make_client(&connection).await;
assert_eq!(client.get("/latest/metadata").await.expect("success"), "ok");
connection.assert_requests_match(&[]);
}
/// A 403 from the token endpoint is not retried: only one request is recorded and
/// the error mentions "forbidden".
#[tokio::test]
#[traced_test]
async fn no_403_retry() {
    let forbidden_token = (
        token_request("http://169.254.169.254", 21600),
        http::Response::builder().status(403).body("").unwrap(),
    );
    let connection = TestConnection::new(vec![forbidden_token]);
    let client = make_client(&connection).await;

    let failure = client.get("/latest/metadata").await.expect_err("no token");
    assert_full_error_contains!(failure, "forbidden");
    connection.assert_requests_match(&[]);
}
/// Checks the retry classifier on 200 responses: a success needs no retry, and a
/// response error wrapping a 200 (e.g. a parse failure) is not retryable.
#[test]
fn successful_response_properly_classified() {
use aws_smithy_http::retry::ClassifyRetry;
let classifier = ImdsResponseRetryClassifier;
// Fresh 200 response with an empty body for each case.
fn response_200() -> operation::Response {
operation::Response::new(imds_response("").map(|_| SdkBody::empty()))
}
let success = SdkSuccess {
raw: response_200(),
parsed: (),
};
assert_eq!(
RetryKind::Unnecessary,
classifier.classify_retry(Ok::<_, &SdkError<()>>(&success))
);
let failure = SdkError::<()>::response_error(
io::Error::new(io::ErrorKind::BrokenPipe, "fail to parse"),
response_200(),
);
assert_eq!(
RetryKind::UnretryableFailure,
classifier.classify_retry(Err::<&SdkSuccess<()>, _>(&failure))
);
}
/// A token response whose body is the bytes `[1, 0]` is rejected with an
/// "invalid token" error.
#[tokio::test]
async fn invalid_token() {
    let bad_token = (
        token_request("http://169.254.169.254", 21600),
        token_response(21600, "replaced").map(|_| vec![1u8, 0u8]),
    );
    let connection = TestConnection::new(vec![bad_token]);
    let client = make_client(&connection).await;

    let failure = client.get("/latest/metadata").await.expect_err("no token");
    assert_full_error_contains!(failure, "invalid token");
    connection.assert_requests_match(&[]);
}
/// A successful metadata response whose body is not valid UTF-8 surfaces an
/// "invalid UTF-8" error.
#[tokio::test]
async fn non_utf8_response() {
    let token_exchange = (
        token_request("http://169.254.169.254", 21600),
        token_response(21600, TOKEN_A).map(SdkBody::from),
    );
    let garbled_lookup = (
        imds_request("http://169.254.169.254/latest/metadata", TOKEN_A),
        http::Response::builder()
            .status(200)
            .body(SdkBody::from(vec![0xA0, 0xA1]))
            .unwrap(),
    );
    let connection = TestConnection::new(vec![token_exchange, garbled_lookup]);
    let client = make_client(&connection).await;

    let failure = client.get("/latest/metadata").await.expect_err("no token");
    assert_full_error_contains!(failure, "invalid UTF-8");
    connection.assert_requests_match(&[]);
}
/// With default settings, a request to an endpoint that never answers must fail with
/// a token-load timeout after more than one but less than two seconds.
///
/// Only compiled when a TLS feature is enabled, because it uses the default connector
/// instead of a test connection.
#[tokio::test]
#[cfg(any(feature = "rustls", feature = "native-tls"))]
async fn one_second_connect_timeout() {
    use crate::imds::client::ImdsError;
    use aws_smithy_types::error::display::DisplayErrorContext;
    use std::time::Instant;

    let client = Client::builder()
        .endpoint(Uri::from_static("http://240.0.0.0"))
        .build()
        .await
        .expect("valid client");

    // `Instant` is monotonic, so the measurement cannot fail or go negative if the
    // wall clock is adjusted while the test is running.
    let started = Instant::now();
    let resp = client
        .get("/latest/metadata")
        .await
        .expect_err("240.0.0.0 will never resolve");
    let time_elapsed = started.elapsed();

    assert!(
        time_elapsed > Duration::from_secs(1),
        "time_elapsed should be greater than 1s but was {:?}",
        time_elapsed
    );
    assert!(
        time_elapsed < Duration::from_secs(2),
        "time_elapsed should be less than 2s but was {:?}",
        time_elapsed
    );

    // The failure must be a token-load error whose full rendering mentions a timeout.
    match resp {
        err @ ImdsError::FailedToLoadToken(_)
            if format!("{}", DisplayErrorContext(&err)).contains("timeout") => {}
        other => panic!(
            "wrong error, expected construction failure with TimedOutError inside: {}",
            other
        ),
    }
}
/// One data-driven endpoint-resolution test case, deserialized from the JSON test file.
#[derive(Debug, Deserialize)]
struct ImdsConfigTest {
/// Environment variables visible to the client.
env: HashMap<String, String>,
/// In-memory file system contents passed to `Fs::from_map`.
fs: HashMap<String, String>,
/// Optional explicit endpoint set on the builder (parsed as a `Uri`).
endpoint_override: Option<String>,
/// Optional endpoint mode set on the builder (parsed as an `EndpointMode`).
mode_override: Option<String>,
/// `Ok(uri)`: the expected request URI; `Err(text)`: substring expected in the build error.
result: Result<String, String>,
/// Description of the case, included in panic messages.
docs: String,
}
/// Runs every endpoint-resolution case found in `test-data/imds-config/imds-tests.json`.
///
/// Fails with the underlying error if the file cannot be read or parsed.
#[tokio::test]
async fn config_tests() -> Result<(), Box<dyn Error>> {
    /// Top-level shape of the JSON test file.
    #[derive(Deserialize)]
    struct TestCases {
        tests: Vec<ImdsConfigTest>,
    }

    let raw = std::fs::read_to_string("test-data/imds-config/imds-tests.json")?;
    let parsed: TestCases = serde_json::from_str(&raw)?;
    for case in parsed.tests {
        check(case).await;
    }
    Ok(())
}
/// Executes one endpoint-resolution case.
///
/// Builds a client from the case's environment, file system and builder overrides,
/// then either checks the build error text or verifies the URI of the first request
/// the client sends.
async fn check(test_case: ImdsConfigTest) {
// `watcher` exposes the request that `server` receives.
let (server, watcher) = capture_request(None);
let provider_config = ProviderConfig::no_configuration()
.with_sleep(TokioSleep::new())
.with_env(Env::from(test_case.env))
.with_fs(Fs::from_map(test_case.fs))
.with_http_connector(DynConnector::new(server));
let mut imds_client = Client::builder().configure(&provider_config);
if let Some(endpoint_override) = test_case.endpoint_override {
imds_client = imds_client.endpoint(endpoint_override.parse::<Uri>().unwrap());
}
if let Some(mode_override) = test_case.mode_override {
imds_client = imds_client.endpoint_mode(mode_override.parse().unwrap());
}
let imds_client = imds_client.build().await;
// Compare the expected outcome with the actual build result.
let (uri, imds_client) = match (&test_case.result, imds_client) {
(Ok(uri), Ok(client)) => (uri, client),
(Err(test), Ok(_client)) => panic!(
"test should fail: {} but a valid client was made. {}",
test, test_case.docs
),
(Err(substr), Err(err)) => {
assert_full_error_contains!(err, substr);
return;
}
(Ok(_uri), Err(e)) => panic!(
"a valid client should be made but: {}. {}",
e, test_case.docs
),
};
// The result of the call is ignored; only the captured request's URI is checked.
let _ = imds_client.get("/hello").await;
assert_eq!(&watcher.expect_request().uri().to_string(), uri);
}
}