use crate::internals::format_status_code_range;
use crate::internals::DebugResponseBody;
use crate::internals::RequestPathFormatter;
use crate::internals::StatusCodeFormatter;
use crate::internals::TryIntoRangeBounds;
use anyhow::Context;
use assert_json_diff::assert_json_include;
use bytes::Bytes;
use cookie::Cookie;
use cookie::CookieJar;
use http::header::HeaderName;
use http::header::SET_COOKIE;
use http::response::Parts;
use http::HeaderMap;
use http::HeaderValue;
use http::Method;
use http::StatusCode;
use serde::de::DeserializeOwned;
use serde::Serialize;
use serde_json::Value;
use std::convert::AsRef;
use std::fmt::Debug;
use std::fmt::Display;
use std::fs::read_to_string;
use std::fs::File;
use std::io::BufReader;
use std::ops::RangeBounds;
use url::Url;
#[cfg(feature = "pretty-assertions")]
use pretty_assertions::{assert_eq, assert_ne};
#[cfg(feature = "ws")]
use crate::internals::TestResponseWebSocket;
#[cfg(feature = "ws")]
use crate::TestWebSocket;
use std::path::Path;
/// A response received from a request made through a `TestServer`, together with enough
/// details of the originating request (method and URL) to describe it in assertion messages.
///
/// The body is held fully in memory as `Bytes`, so reading it is synchronous.
#[derive(Clone, Debug)]
pub struct TestResponse {
/// HTTP method of the request that produced this response.
method: Method,
/// Full URL the request was sent to; used when describing the request in panic messages.
full_request_url: Url,
/// Headers returned with the response.
headers: HeaderMap<HeaderValue>,
/// Status code returned with the response.
status_code: StatusCode,
/// The complete response body.
response_body: Bytes,
/// WebSocket upgrade state; consumed by `into_websocket`.
#[cfg(feature = "ws")]
websockets: TestResponseWebSocket,
}
impl TestResponse {
    /// Assembles a `TestResponse` from the request that was sent and the response received.
    ///
    /// Only `headers` and `status` are taken from `parts`; the body is passed in separately,
    /// already collected into `response_body`.
    pub(crate) fn new(
        method: Method,
        full_request_url: Url,
        parts: Parts,
        response_body: Bytes,
        #[cfg(feature = "ws")] websockets: TestResponseWebSocket,
    ) -> Self {
        Self {
            method,
            full_request_url,
            headers: parts.headers,
            status_code: parts.status,
            response_body,
            #[cfg(feature = "ws")]
            websockets,
        }
    }

    /// Returns the response body decoded as UTF-8 text.
    ///
    /// Invalid UTF-8 sequences are replaced with `U+FFFD` instead of causing a panic.
    #[must_use]
    pub fn text(&self) -> String {
        // `into_owned` avoids a second copy when the lossy decode already had to allocate.
        String::from_utf8_lossy(self.as_bytes()).into_owned()
    }

    /// Deserializes the response body as JSON into `T`.
    ///
    /// # Panics
    ///
    /// Panics if the body cannot be deserialized into `T`; the message describes the request.
    #[must_use]
    pub fn json<T>(&self) -> T
    where
        T: DeserializeOwned,
    {
        serde_json::from_slice::<T>(self.as_bytes())
            .with_context(|| {
                let debug_request_format = self.debug_request_format();
                format!("Deserializing response from Json, for request {debug_request_format}")
            })
            .unwrap()
    }

    /// Deserializes the response body as YAML into `T`.
    ///
    /// # Panics
    ///
    /// Panics if the body cannot be deserialized into `T`; the message describes the request.
    #[cfg(feature = "yaml")]
    #[must_use]
    pub fn yaml<T>(&self) -> T
    where
        T: DeserializeOwned,
    {
        serde_yaml::from_slice::<T>(self.as_bytes())
            .with_context(|| {
                let debug_request_format = self.debug_request_format();
                format!("Deserializing response from YAML, for request {debug_request_format}")
            })
            .unwrap()
    }

    /// Deserializes the response body as MessagePack into `T`.
    ///
    /// # Panics
    ///
    /// Panics if the body cannot be deserialized into `T`; the message describes the request.
    #[cfg(feature = "msgpack")]
    #[must_use]
    pub fn msgpack<T>(&self) -> T
    where
        T: DeserializeOwned,
    {
        rmp_serde::from_slice::<T>(self.as_bytes())
            .with_context(|| {
                let debug_request_format = self.debug_request_format();
                format!("Deserializing response from MsgPack, for request {debug_request_format}")
            })
            .unwrap()
    }

    /// Deserializes the response body as URL-encoded form data into `T`.
    ///
    /// # Panics
    ///
    /// Panics if the body cannot be deserialized into `T`; the message describes the request.
    #[must_use]
    pub fn form<T>(&self) -> T
    where
        T: DeserializeOwned,
    {
        serde_urlencoded::from_bytes::<T>(self.as_bytes())
            .with_context(|| {
                let debug_request_format = self.debug_request_format();
                format!("Deserializing response from Form, for request {debug_request_format}")
            })
            .unwrap()
    }

    /// Borrows the raw response body.
    #[must_use]
    pub fn as_bytes(&self) -> &Bytes {
        &self.response_body
    }

    /// Consumes the response, returning the raw body.
    #[must_use]
    pub fn into_bytes(self) -> Bytes {
        self.response_body
    }

    /// Returns the response's status code.
    #[must_use]
    pub fn status_code(&self) -> StatusCode {
        self.status_code
    }

    /// Returns the HTTP method of the request that produced this response.
    #[must_use]
    pub fn request_method(&self) -> Method {
        self.method.clone()
    }

    /// Returns the full URL the request was sent to.
    #[must_use]
    pub fn request_url(&self) -> Url {
        self.full_request_url.clone()
    }

    /// Returns a copy of header `name`, or `None` when the response does not contain it.
    ///
    /// When the header appears multiple times, the first value is returned.
    ///
    /// # Panics
    ///
    /// Panics if `name` cannot be converted into a [`HeaderName`].
    #[must_use]
    pub fn maybe_header<N>(&self, name: N) -> Option<HeaderValue>
    where
        N: TryInto<HeaderName>,
        N::Error: Debug,
    {
        let header_name = name
            .try_into()
            .expect("Failed to build HeaderName from name given");
        self.headers.get(header_name).cloned()
    }

    /// Borrows all response headers.
    #[must_use]
    pub fn headers(&self) -> &HeaderMap<HeaderValue> {
        &self.headers
    }

    /// Returns the `Content-Type` header as a `String`, or `None` when it is absent.
    ///
    /// # Panics
    ///
    /// Panics if the header is present but is not visible ASCII.
    #[must_use]
    pub fn maybe_content_type(&self) -> Option<String> {
        self.headers.get(http::header::CONTENT_TYPE).map(|header| {
            header
                .to_str()
                .with_context(|| {
                    format!("Failed to decode header CONTENT_TYPE, received '{header:?}'")
                })
                .unwrap()
                .to_string()
        })
    }

    /// Returns the `Content-Type` header as a `String`.
    ///
    /// # Panics
    ///
    /// Panics if the header is missing or cannot be decoded.
    #[must_use]
    pub fn content_type(&self) -> String {
        self.maybe_content_type()
            .expect("CONTENT_TYPE not found in response header")
    }

    /// Returns a copy of header `name`.
    ///
    /// # Panics
    ///
    /// Panics if `name` cannot be converted into a [`HeaderName`], or if the response does not
    /// contain the header.
    #[must_use]
    pub fn header<N>(&self, name: N) -> HeaderValue
    where
        N: TryInto<HeaderName> + Display + Clone,
        N::Error: Debug,
    {
        let debug_header = name.clone();
        // `expect` takes a plain `&str` and does not interpolate `{…}`, so the message is built
        // with `panic!` to actually show the offending name (and the conversion error, as
        // `expect` would have).
        let header_name = name.try_into().unwrap_or_else(|err| {
            panic!("Failed to build HeaderName from name given, '{debug_header}': {err:?}")
        });
        self.headers
            .get(header_name)
            .cloned()
            .with_context(|| {
                let debug_request_format = self.debug_request_format();
                format!("Cannot find header {debug_header}, for request {debug_request_format}",)
            })
            .unwrap()
    }

    /// Iterates over every response header as `(name, value)` pairs.
    ///
    /// A name with several values is yielded once per value.
    pub fn iter_headers(&self) -> impl Iterator<Item = (&'_ HeaderName, &'_ HeaderValue)> {
        self.headers.iter()
    }

    /// Iterates over every value of header `name` (empty when the header is absent).
    ///
    /// # Panics
    ///
    /// Panics if `name` cannot be converted into a [`HeaderName`].
    pub fn iter_headers_by_name<N>(&self, name: N) -> impl Iterator<Item = &'_ HeaderValue>
    where
        N: TryInto<HeaderName>,
        N::Error: Debug,
    {
        let header_name = name
            .try_into()
            .expect("Failed to build HeaderName from name given");
        self.headers.get_all(header_name).iter()
    }

    /// Returns `true` when the response contains header `name`.
    ///
    /// # Panics
    ///
    /// Panics if `name` cannot be converted into a [`HeaderName`].
    #[must_use]
    pub fn contains_header<N>(&self, name: N) -> bool
    where
        N: TryInto<HeaderName>,
        N::Error: Debug,
    {
        let header_name = name
            .try_into()
            .expect("Failed to build HeaderName from name given");
        self.headers.contains_key(header_name)
    }

    /// Asserts that the response contains header `name`, regardless of its value.
    #[track_caller]
    pub fn assert_contains_header<N>(&self, name: N)
    where
        N: TryInto<HeaderName> + Display + Clone,
        N::Error: Debug,
    {
        let debug_header_name = name.clone();
        let debug_request_format = self.debug_request_format();
        let has_header = self.contains_header(name);
        assert!(has_header, "Expected header '{debug_header_name}' to be present in response, header was not found, for request {debug_request_format}");
    }

    /// Asserts that header `name` is present and its (first) value equals `value`.
    ///
    /// # Panics
    ///
    /// Panics if `name` or `value` cannot be converted, if the header is missing, or if the
    /// value differs.
    #[track_caller]
    pub fn assert_header<N, V>(&self, name: N, value: V)
    where
        N: TryInto<HeaderName> + Display + Clone,
        N::Error: Debug,
        V: TryInto<HeaderValue>,
        V::Error: Debug,
    {
        let debug_header_name = name.clone();
        let header_name = name
            .try_into()
            .expect("Failed to build HeaderName from name given");
        let expected_header_value = value
            .try_into()
            .expect("Could not turn given value into HeaderValue");
        let debug_request_format = self.debug_request_format();
        let maybe_found_header_value = self.maybe_header(header_name);
        match maybe_found_header_value {
            None => {
                panic!("Expected header '{debug_header_name}' to be present in response, header was not found, for request {debug_request_format}")
            }
            Some(found_header_value) => {
                // Include the header name and request, like every other assertion here does.
                assert_eq!(
                    expected_header_value, found_header_value,
                    "Header '{debug_header_name}' does not match the expected value, for request {debug_request_format}"
                )
            }
        }
    }

    /// Returns the first `Set-Cookie` cookie named `cookie_name`, or `None` if there is none.
    ///
    /// # Panics
    ///
    /// Panics if a `Set-Cookie` header examined before the match cannot be parsed.
    #[must_use]
    pub fn maybe_cookie(&self, cookie_name: &str) -> Option<Cookie<'static>> {
        self.iter_cookies()
            .find(|cookie| cookie.name() == cookie_name)
            .map(|cookie| cookie.into_owned())
    }

    /// Returns the first `Set-Cookie` cookie named `cookie_name`.
    ///
    /// # Panics
    ///
    /// Panics if no such cookie exists, or if a `Set-Cookie` header cannot be parsed.
    #[must_use]
    pub fn cookie(&self, cookie_name: &str) -> Cookie<'static> {
        self.maybe_cookie(cookie_name)
            .with_context(|| {
                let debug_request_format = self.debug_request_format();
                format!("Cannot find cookie {cookie_name}, for request {debug_request_format}")
            })
            .unwrap()
    }

    /// Collects every `Set-Cookie` cookie into a [`CookieJar`].
    ///
    /// # Panics
    ///
    /// Panics if any `Set-Cookie` header cannot be parsed.
    #[must_use]
    pub fn cookies(&self) -> CookieJar {
        let mut cookies = CookieJar::new();
        for cookie in self.iter_cookies() {
            cookies.add(cookie.into_owned());
        }
        cookies
    }

    /// Lazily parses each `Set-Cookie` header into a [`Cookie`] borrowing from the response.
    ///
    /// # Panics
    ///
    /// While iterating, panics if a header is not a valid string or is not a valid cookie.
    pub fn iter_cookies(&self) -> impl Iterator<Item = Cookie<'_>> {
        self.iter_headers_by_name(SET_COOKIE).map(|header| {
            let header_str = header
                .to_str()
                .with_context(|| {
                    let debug_request_format = self.debug_request_format();
                    format!(
                        "Reading header 'Set-Cookie' as string, for request {debug_request_format}",
                    )
                })
                .unwrap();
            Cookie::parse(header_str)
                .with_context(|| {
                    let debug_request_format = self.debug_request_format();
                    format!("Parsing 'Set-Cookie' header, for request {debug_request_format}",)
                })
                .unwrap()
        })
    }

    /// Consumes the response and completes the WebSocket upgrade into a [`TestWebSocket`].
    ///
    /// # Panics
    ///
    /// Panics if the transport is not HTTP based, if the response carries no upgrade, or if
    /// the upgrade fails.
    #[cfg(feature = "ws")]
    #[must_use]
    pub async fn into_websocket(self) -> TestWebSocket {
        use crate::transport_layer::TransportLayerType;
        if self.websockets.transport_type != TransportLayerType::Http {
            unimplemented!("WebSocket requires a HTTP based transport layer, see `TestServerConfig::transport`");
        }
        // Captured as an owned String because `self` is partially moved below.
        let debug_request_format = self.debug_request_format().to_string();
        let on_upgrade = self
            .websockets
            .maybe_on_upgrade
            .with_context(|| {
                format!("Expected WebSocket upgrade to be found, it is None, for request {debug_request_format}")
            })
            .unwrap();
        let upgraded = on_upgrade
            .await
            .with_context(|| {
                format!("Failed to upgrade connection for, for request {debug_request_format}")
            })
            .unwrap();
        TestWebSocket::new(upgraded).await
    }

    /// Asserts the whole body, read as text, equals `expected`.
    #[track_caller]
    pub fn assert_text<C>(&self, expected: C)
    where
        C: AsRef<str>,
    {
        let expected_contents = expected.as_ref();
        assert_eq!(expected_contents, &self.text());
    }

    /// Asserts the body, read as text, contains `expected` as a substring.
    #[track_caller]
    pub fn assert_text_contains<C>(&self, expected: C)
    where
        C: AsRef<str>,
    {
        let expected_contents = expected.as_ref();
        let received = self.text();
        let is_contained = received.contains(expected_contents);
        assert!(
            is_contained,
            "Failed to find '{expected_contents}', received '{received}'"
        );
    }

    /// Asserts the whole body, read as text, equals the contents of the file at `path`.
    ///
    /// # Panics
    ///
    /// Panics if the file cannot be read, or if the text differs.
    #[track_caller]
    pub fn assert_text_from_file<P>(&self, path: P)
    where
        P: AsRef<Path>,
    {
        let path_ref = path.as_ref();
        let expected = read_to_string(path_ref)
            .with_context(|| format!("Failed to read from file '{}'", path_ref.display()))
            .unwrap();
        self.assert_text(expected);
    }

    /// Asserts the body, deserialized as JSON into `T`, equals `expected`.
    #[track_caller]
    pub fn assert_json<T>(&self, expected: &T)
    where
        T: DeserializeOwned + PartialEq<T> + Debug,
    {
        assert_eq!(*expected, self.json::<T>());
    }

    /// Asserts the JSON body includes `expected`, as defined by
    /// `assert_json_diff::assert_json_include!` (the response may carry extra fields).
    #[track_caller]
    pub fn assert_json_contains<T>(&self, expected: &T)
    where
        T: Serialize,
    {
        let received = self.json::<Value>();
        assert_json_include!(actual: received, expected: expected);
    }

    /// Asserts the JSON body equals the JSON document stored in the file at `path`.
    ///
    /// # Panics
    ///
    /// Panics if the file cannot be opened or parsed as JSON, or if the documents differ.
    #[track_caller]
    pub fn assert_json_from_file<P>(&self, path: P)
    where
        P: AsRef<Path>,
    {
        let path_ref = path.as_ref();
        let file = File::open(path_ref)
            .with_context(|| format!("Failed to read from file '{}'", path_ref.display()))
            .unwrap();
        let reader = BufReader::new(file);
        let expected = serde_json::from_reader::<_, serde_json::Value>(reader)
            .with_context(|| {
                format!(
                    "Failed to deserialize file '{}' as json",
                    path_ref.display()
                )
            })
            .unwrap();
        self.assert_json(&expected);
    }

    /// Asserts the body, deserialized as YAML into `T`, equals `other`.
    #[cfg(feature = "yaml")]
    #[track_caller]
    pub fn assert_yaml<T>(&self, other: &T)
    where
        T: DeserializeOwned + PartialEq<T> + Debug,
    {
        assert_eq!(*other, self.yaml::<T>());
    }

    /// Asserts the YAML body equals the YAML document stored in the file at `path`.
    ///
    /// # Panics
    ///
    /// Panics if the file cannot be opened or parsed as YAML, or if the documents differ.
    #[cfg(feature = "yaml")]
    #[track_caller]
    pub fn assert_yaml_from_file<P>(&self, path: P)
    where
        P: AsRef<Path>,
    {
        let path_ref = path.as_ref();
        let file = File::open(path_ref)
            .with_context(|| format!("Failed to read from file '{}'", path_ref.display()))
            .unwrap();
        let reader = BufReader::new(file);
        let expected = serde_yaml::from_reader::<_, serde_yaml::Value>(reader)
            .with_context(|| {
                format!(
                    "Failed to deserialize file '{}' as yaml",
                    path_ref.display()
                )
            })
            .unwrap();
        self.assert_yaml(&expected);
    }

    /// Asserts the body, deserialized as MessagePack into `T`, equals `other`.
    #[cfg(feature = "msgpack")]
    #[track_caller]
    pub fn assert_msgpack<T>(&self, other: &T)
    where
        T: DeserializeOwned + PartialEq<T> + Debug,
    {
        assert_eq!(*other, self.msgpack::<T>());
    }

    /// Asserts the body, deserialized as URL-encoded form data into `T`, equals `other`.
    #[track_caller]
    pub fn assert_form<T>(&self, other: &T)
    where
        T: DeserializeOwned + PartialEq<T> + Debug,
    {
        assert_eq!(*other, self.form::<T>());
    }

    /// Asserts the status code equals `expected_status_code`.
    #[track_caller]
    pub fn assert_status(&self, expected_status_code: StatusCode) {
        let received_debug = StatusCodeFormatter(self.status_code);
        let expected_debug = StatusCodeFormatter(expected_status_code);
        let debug_request_format = self.debug_request_format();
        let debug_body = DebugResponseBody(self);
        assert_eq!(
            expected_status_code, self.status_code,
            "Expected status code to be {expected_debug}, received {received_debug}, for request {debug_request_format}, with body {debug_body}"
        );
    }

    /// Asserts the status code differs from `expected_status_code`.
    #[track_caller]
    pub fn assert_not_status(&self, expected_status_code: StatusCode) {
        let received_debug = StatusCodeFormatter(self.status_code);
        let expected_debug = StatusCodeFormatter(expected_status_code);
        let debug_request_format = self.debug_request_format();
        let debug_body = DebugResponseBody(self);
        assert_ne!(
            expected_status_code,
            self.status_code,
            "Expected status code to not be {expected_debug}, received {received_debug}, for request {debug_request_format}, with body {debug_body}"
        );
    }

    /// Asserts the status code is in the 2xx range (`200..=299`).
    #[track_caller]
    pub fn assert_status_success(&self) {
        let status_code = self.status_code.as_u16();
        let received_debug = StatusCodeFormatter(self.status_code);
        let debug_request_format = self.debug_request_format();
        let debug_body = DebugResponseBody(self);
        assert!(
            (200..=299).contains(&status_code),
            "Expect status code within 2xx range, received {received_debug}, for request {debug_request_format}, with body {debug_body}"
        );
    }

    /// Asserts the status code is outside the 2xx range (`200..=299`).
    #[track_caller]
    pub fn assert_status_failure(&self) {
        let status_code = self.status_code.as_u16();
        let received_debug = StatusCodeFormatter(self.status_code);
        let debug_request_format = self.debug_request_format();
        let debug_body = DebugResponseBody(self);
        assert!(
            !(200..=299).contains(&status_code),
            "Expect status code outside 2xx range, received {received_debug}, for request {debug_request_format}, with body {debug_body}"
        );
    }

    /// Asserts the status code lies within `expected_status_range`, which may be built from
    /// integers (`200..299`) or from `StatusCode`s.
    ///
    /// # Panics
    ///
    /// Panics if the range bounds cannot be converted to `StatusCode`s, or if the status code
    /// is outside the range.
    #[track_caller]
    pub fn assert_status_in_range<R, S>(&self, expected_status_range: R)
    where
        R: RangeBounds<S> + TryIntoRangeBounds<StatusCode> + Debug,
        S: TryInto<StatusCode>,
    {
        let range = TryIntoRangeBounds::<StatusCode>::try_into_range_bounds(expected_status_range)
            .expect("Failed to convert status code");
        let status_code = self.status_code();
        let is_in_range = range.contains(&status_code);
        let debug_request_format = self.debug_request_format();
        let debug_body = DebugResponseBody(self);
        assert!(
            is_in_range,
            "Expected status to be in range {}, received {status_code}, for request {debug_request_format}, with body {debug_body}",
            format_status_code_range(range)
        );
    }

    /// Asserts the status code lies outside `expected_status_range`, which may be built from
    /// integers or from `StatusCode`s.
    ///
    /// # Panics
    ///
    /// Panics if the range bounds cannot be converted to `StatusCode`s, or if the status code
    /// is inside the range.
    #[track_caller]
    pub fn assert_status_not_in_range<R, S>(&self, expected_status_range: R)
    where
        R: RangeBounds<S> + TryIntoRangeBounds<StatusCode> + Debug,
        S: TryInto<StatusCode>,
    {
        let range = TryIntoRangeBounds::<StatusCode>::try_into_range_bounds(expected_status_range)
            .expect("Failed to convert status code");
        let status_code = self.status_code();
        let is_not_in_range = !range.contains(&status_code);
        let debug_request_format = self.debug_request_format();
        let debug_body = DebugResponseBody(self);
        assert!(
            is_not_in_range,
            "Expected status is not in range {}, received {status_code}, for request {debug_request_format}, with body {debug_body}",
            format_status_code_range(range)
        );
    }

    /// Asserts the status code is `200 OK`.
    #[track_caller]
    pub fn assert_status_ok(&self) {
        self.assert_status(StatusCode::OK)
    }

    /// Asserts the status code is anything other than `200 OK`.
    #[track_caller]
    pub fn assert_status_not_ok(&self) {
        self.assert_not_status(StatusCode::OK)
    }

    /// Asserts the status code is `303 See Other`.
    #[track_caller]
    pub fn assert_status_see_other(&self) {
        self.assert_status(StatusCode::SEE_OTHER)
    }

    /// Asserts the status code is `400 Bad Request`.
    #[track_caller]
    pub fn assert_status_bad_request(&self) {
        self.assert_status(StatusCode::BAD_REQUEST)
    }

    /// Asserts the status code is `404 Not Found`.
    #[track_caller]
    pub fn assert_status_not_found(&self) {
        self.assert_status(StatusCode::NOT_FOUND)
    }

    /// Asserts the status code is `401 Unauthorized`.
    #[track_caller]
    pub fn assert_status_unauthorized(&self) {
        self.assert_status(StatusCode::UNAUTHORIZED)
    }

    /// Asserts the status code is `403 Forbidden`.
    #[track_caller]
    pub fn assert_status_forbidden(&self) {
        self.assert_status(StatusCode::FORBIDDEN)
    }

    /// Asserts the status code is `409 Conflict`.
    #[track_caller]
    pub fn assert_status_conflict(&self) {
        self.assert_status(StatusCode::CONFLICT)
    }

    /// Asserts the status code is `413 Payload Too Large`.
    #[track_caller]
    pub fn assert_status_payload_too_large(&self) {
        self.assert_status(StatusCode::PAYLOAD_TOO_LARGE)
    }

    /// Asserts the status code is `422 Unprocessable Entity`.
    #[track_caller]
    pub fn assert_status_unprocessable_entity(&self) {
        self.assert_status(StatusCode::UNPROCESSABLE_ENTITY)
    }

    /// Asserts the status code is `429 Too Many Requests`.
    #[track_caller]
    pub fn assert_status_too_many_requests(&self) {
        self.assert_status(StatusCode::TOO_MANY_REQUESTS)
    }

    /// Asserts the status code is `101 Switching Protocols`.
    #[track_caller]
    pub fn assert_status_switching_protocols(&self) {
        self.assert_status(StatusCode::SWITCHING_PROTOCOLS)
    }

    /// Asserts the status code is `500 Internal Server Error`.
    #[track_caller]
    pub fn assert_status_internal_server_error(&self) {
        self.assert_status(StatusCode::INTERNAL_SERVER_ERROR)
    }

    /// Asserts the status code is `503 Service Unavailable`.
    #[track_caller]
    pub fn assert_status_service_unavailable(&self) {
        self.assert_status(StatusCode::SERVICE_UNAVAILABLE)
    }

    /// Builds a displayable description of the originating request (its method and full
    /// URL) for use in panic messages.
    fn debug_request_format(&self) -> RequestPathFormatter<'_> {
        RequestPathFormatter::new(&self.method, self.full_request_url.as_str(), None)
    }
}
impl From<TestResponse> for Bytes {
fn from(response: TestResponse) -> Self {
response.into_bytes()
}
}
#[cfg(test)]
mod test_assert_header {
    //! Tests for `TestResponse::assert_header`.
    use crate::TestServer;
    use axum::http::HeaderMap;
    use axum::routing::get;
    use axum::Router;

    /// Responds with the single header `x-my-custom-header: content`.
    async fn route_get_header() -> HeaderMap {
        let mut response_headers = HeaderMap::new();
        let header_value = axum::http::HeaderValue::from_static("content");
        response_headers.insert("x-my-custom-header", header_value);
        response_headers
    }

    /// Builds a server with `/header` routed to `route_get_header`.
    fn new_test_server() -> TestServer {
        let router = Router::new().route("/header", get(route_get_header));
        TestServer::new(router).unwrap()
    }

    #[tokio::test]
    async fn it_should_not_panic_if_contains_header_and_content_matches() {
        let server = new_test_server();
        let response = server.get("/header").await;
        response.assert_header("x-my-custom-header", "content");
    }

    #[tokio::test]
    #[should_panic]
    async fn it_should_panic_if_contains_header_and_content_does_not_match() {
        let server = new_test_server();
        let response = server.get("/header").await;
        response.assert_header("x-my-custom-header", "different-content");
    }

    #[tokio::test]
    #[should_panic]
    async fn it_should_panic_if_not_contains_header() {
        let server = new_test_server();
        let response = server.get("/header").await;
        response.assert_header("x-custom-header-not-found", "content");
    }
}
#[cfg(test)]
mod test_assert_contains_header {
    //! Tests for `TestResponse::assert_contains_header`.
    use crate::TestServer;
    use axum::http::HeaderMap;
    use axum::routing::get;
    use axum::Router;

    /// Responds with the single header `x-my-custom-header: content`.
    async fn route_get_header() -> HeaderMap {
        let mut response_headers = HeaderMap::new();
        let header_value = axum::http::HeaderValue::from_static("content");
        response_headers.insert("x-my-custom-header", header_value);
        response_headers
    }

    /// Builds a server with `/header` routed to `route_get_header`.
    fn new_test_server() -> TestServer {
        let router = Router::new().route("/header", get(route_get_header));
        TestServer::new(router).unwrap()
    }

    #[tokio::test]
    async fn it_should_not_panic_if_contains_header() {
        let server = new_test_server();
        let response = server.get("/header").await;
        response.assert_contains_header("x-my-custom-header");
    }

    #[tokio::test]
    #[should_panic]
    async fn it_should_panic_if_not_contains_header() {
        let server = new_test_server();
        let response = server.get("/header").await;
        response.assert_contains_header("x-custom-header-not-found");
    }
}
#[cfg(test)]
mod test_assert_success {
    //! Tests for `TestResponse::assert_status_success`.
    use crate::TestServer;
    use axum::routing::get;
    use axum::Router;
    use http::StatusCode;

    pub async fn route_get_pass() -> StatusCode {
        StatusCode::OK
    }

    pub async fn route_get_fail() -> StatusCode {
        StatusCode::SERVICE_UNAVAILABLE
    }

    /// Builds a server exposing both a passing (`/pass`) and a failing (`/fail`) route.
    fn new_test_server() -> TestServer {
        let router = Router::new()
            .route("/pass", get(route_get_pass))
            .route("/fail", get(route_get_fail));
        TestServer::new(router).unwrap()
    }

    #[tokio::test]
    async fn it_should_pass_when_200() {
        let server = new_test_server();
        server.get("/pass").await.assert_status_success()
    }

    #[tokio::test]
    #[should_panic]
    async fn it_should_panic_when_not_200() {
        let server = new_test_server();
        // `expect_failure` stops the request itself from rejecting the 503.
        let failing_request = server.get("/fail").expect_failure();
        failing_request.await.assert_status_success()
    }
}
#[cfg(test)]
mod test_assert_failure {
    //! Tests for `TestResponse::assert_status_failure`.
    use crate::TestServer;
    use axum::routing::get;
    use axum::Router;
    use http::StatusCode;

    pub async fn route_get_pass() -> StatusCode {
        StatusCode::OK
    }

    pub async fn route_get_fail() -> StatusCode {
        StatusCode::SERVICE_UNAVAILABLE
    }

    /// Builds a server exposing both a passing (`/pass`) and a failing (`/fail`) route.
    fn new_test_server() -> TestServer {
        let router = Router::new()
            .route("/pass", get(route_get_pass))
            .route("/fail", get(route_get_fail));
        TestServer::new(router).unwrap()
    }

    #[tokio::test]
    async fn it_should_pass_when_not_200() {
        let server = new_test_server();
        // `expect_failure` stops the request itself from rejecting the 503.
        let failing_request = server.get("/fail").expect_failure();
        failing_request.await.assert_status_failure()
    }

    #[tokio::test]
    #[should_panic]
    async fn it_should_panic_when_200() {
        let server = new_test_server();
        server.get("/pass").await.assert_status_failure()
    }
}
#[cfg(test)]
mod test_assert_status {
    //! Tests for `TestResponse::assert_status`.
    use crate::TestServer;
    use axum::routing::get;
    use axum::Router;
    use http::StatusCode;

    pub async fn route_get_ok() -> StatusCode {
        StatusCode::OK
    }

    /// Builds a server whose `/ok` route always answers `200 OK`.
    fn new_test_server() -> TestServer {
        let router = Router::new().route("/ok", get(route_get_ok));
        TestServer::new(router).unwrap()
    }

    #[tokio::test]
    async fn it_should_pass_if_given_right_status_code() {
        let server = new_test_server();
        let response = server.get("/ok").await;
        response.assert_status(StatusCode::OK);
    }

    #[tokio::test]
    #[should_panic]
    async fn it_should_panic_when_status_code_does_not_match() {
        let server = new_test_server();
        let response = server.get("/ok").await;
        response.assert_status(StatusCode::ACCEPTED);
    }
}
#[cfg(test)]
mod test_assert_not_status {
    //! Tests for `TestResponse::assert_not_status`.
    use crate::TestServer;
    use axum::routing::get;
    use axum::Router;
    use http::StatusCode;

    pub async fn route_get_ok() -> StatusCode {
        StatusCode::OK
    }

    /// Builds a server whose `/ok` route always answers `200 OK`.
    fn new_test_server() -> TestServer {
        let router = Router::new().route("/ok", get(route_get_ok));
        TestServer::new(router).unwrap()
    }

    #[tokio::test]
    async fn it_should_pass_if_status_code_does_not_match() {
        let server = new_test_server();
        let response = server.get("/ok").await;
        response.assert_not_status(StatusCode::ACCEPTED);
    }

    #[tokio::test]
    #[should_panic]
    async fn it_should_panic_if_status_code_matches() {
        let server = new_test_server();
        let response = server.get("/ok").await;
        response.assert_not_status(StatusCode::OK);
    }
}
#[cfg(test)]
mod test_assert_status_in_range {
    //! Tests for `TestResponse::assert_status_in_range` over every range flavour.
    use super::TestResponse;
    use crate::TestServer;
    use axum::routing::get;
    use axum::routing::Router;
    use http::StatusCode;
    use std::ops::RangeFull;

    /// Requests `/status` from a server whose only route always answers with `status_code`.
    async fn response_with_status(status_code: StatusCode) -> TestResponse {
        let app = Router::new().route("/status", get(move || async move { status_code }));
        let server = TestServer::new(app).unwrap();
        server.get("/status").await
    }

    #[tokio::test]
    async fn it_should_be_true_when_within_int_range() {
        response_with_status(StatusCode::NON_AUTHORITATIVE_INFORMATION)
            .await
            .assert_status_in_range(200..299);
    }

    #[tokio::test]
    async fn it_should_be_true_when_within_status_code_range() {
        response_with_status(StatusCode::NON_AUTHORITATIVE_INFORMATION)
            .await
            .assert_status_in_range(StatusCode::OK..StatusCode::IM_USED);
    }

    #[tokio::test]
    #[should_panic]
    async fn it_should_be_false_when_outside_int_range() {
        response_with_status(StatusCode::INTERNAL_SERVER_ERROR)
            .await
            .assert_status_in_range(200..299);
    }

    #[tokio::test]
    #[should_panic]
    async fn it_should_be_false_when_outside_status_code_range() {
        response_with_status(StatusCode::INTERNAL_SERVER_ERROR)
            .await
            .assert_status_in_range(StatusCode::OK..StatusCode::IM_USED);
    }

    #[tokio::test]
    async fn it_should_be_true_when_within_inclusive_range() {
        response_with_status(StatusCode::NON_AUTHORITATIVE_INFORMATION)
            .await
            .assert_status_in_range(200..=299);
    }

    #[tokio::test]
    #[should_panic]
    async fn it_should_be_false_when_outside_inclusive_range() {
        response_with_status(StatusCode::INTERNAL_SERVER_ERROR)
            .await
            .assert_status_in_range(200..=299);
    }

    #[tokio::test]
    async fn it_should_be_true_when_within_to_range() {
        response_with_status(StatusCode::NON_AUTHORITATIVE_INFORMATION)
            .await
            .assert_status_in_range(..299);
    }

    #[tokio::test]
    #[should_panic]
    async fn it_should_be_false_when_outside_to_range() {
        response_with_status(StatusCode::INTERNAL_SERVER_ERROR)
            .await
            .assert_status_in_range(..299);
    }

    #[tokio::test]
    async fn it_should_be_true_when_within_to_inclusive_range() {
        response_with_status(StatusCode::NON_AUTHORITATIVE_INFORMATION)
            .await
            .assert_status_in_range(..=299);
    }

    #[tokio::test]
    #[should_panic]
    async fn it_should_be_false_when_outside_to_inclusive_range() {
        response_with_status(StatusCode::INTERNAL_SERVER_ERROR)
            .await
            .assert_status_in_range(..=299);
    }

    #[tokio::test]
    async fn it_should_be_true_when_within_from_range() {
        response_with_status(StatusCode::NON_AUTHORITATIVE_INFORMATION)
            .await
            .assert_status_in_range(200..);
    }

    #[tokio::test]
    #[should_panic]
    async fn it_should_be_false_when_outside_from_range() {
        response_with_status(StatusCode::NON_AUTHORITATIVE_INFORMATION)
            .await
            .assert_status_in_range(500..);
    }

    #[tokio::test]
    async fn it_should_be_true_for_full_range() {
        // `..` carries no element type, so it must be spelled out for inference.
        response_with_status(StatusCode::NON_AUTHORITATIVE_INFORMATION)
            .await
            .assert_status_in_range::<RangeFull, StatusCode>(..);
    }
}
#[cfg(test)]
mod test_assert_status_not_in_range {
    //! Tests for `TestResponse::assert_status_not_in_range` over every range flavour.
    use super::TestResponse;
    use crate::TestServer;
    use axum::routing::get;
    use axum::routing::Router;
    use http::StatusCode;
    use std::ops::RangeFull;

    /// Requests `/status` from a server whose only route always answers with `status_code`.
    async fn response_with_status(status_code: StatusCode) -> TestResponse {
        let app = Router::new().route("/status", get(move || async move { status_code }));
        let server = TestServer::new(app).unwrap();
        server.get("/status").await
    }

    #[tokio::test]
    #[should_panic]
    async fn it_should_be_false_when_within_int_range() {
        response_with_status(StatusCode::NON_AUTHORITATIVE_INFORMATION)
            .await
            .assert_status_not_in_range(200..299);
    }

    #[tokio::test]
    #[should_panic]
    async fn it_should_be_false_when_within_status_code_range() {
        response_with_status(StatusCode::NON_AUTHORITATIVE_INFORMATION)
            .await
            .assert_status_not_in_range(StatusCode::OK..StatusCode::IM_USED);
    }

    #[tokio::test]
    async fn it_should_be_true_when_outside_int_range() {
        response_with_status(StatusCode::INTERNAL_SERVER_ERROR)
            .await
            .assert_status_not_in_range(200..299);
    }

    #[tokio::test]
    async fn it_should_be_true_when_outside_status_code_range() {
        response_with_status(StatusCode::INTERNAL_SERVER_ERROR)
            .await
            .assert_status_not_in_range(StatusCode::OK..StatusCode::IM_USED);
    }

    #[tokio::test]
    #[should_panic]
    async fn it_should_be_false_when_within_inclusive_range() {
        response_with_status(StatusCode::NON_AUTHORITATIVE_INFORMATION)
            .await
            .assert_status_not_in_range(200..=299);
    }

    #[tokio::test]
    async fn it_should_be_true_when_outside_inclusive_range() {
        response_with_status(StatusCode::INTERNAL_SERVER_ERROR)
            .await
            .assert_status_not_in_range(200..=299);
    }

    #[tokio::test]
    #[should_panic]
    async fn it_should_be_false_when_within_to_range() {
        response_with_status(StatusCode::NON_AUTHORITATIVE_INFORMATION)
            .await
            .assert_status_not_in_range(..299);
    }

    #[tokio::test]
    async fn it_should_be_true_when_outside_to_range() {
        response_with_status(StatusCode::INTERNAL_SERVER_ERROR)
            .await
            .assert_status_not_in_range(..299);
    }

    #[tokio::test]
    #[should_panic]
    async fn it_should_be_false_when_within_to_inclusive_range() {
        response_with_status(StatusCode::NON_AUTHORITATIVE_INFORMATION)
            .await
            .assert_status_not_in_range(..=299);
    }

    #[tokio::test]
    async fn it_should_be_true_when_outside_to_inclusive_range() {
        response_with_status(StatusCode::INTERNAL_SERVER_ERROR)
            .await
            .assert_status_not_in_range(..=299);
    }

    #[tokio::test]
    #[should_panic]
    async fn it_should_be_false_when_within_from_range() {
        response_with_status(StatusCode::NON_AUTHORITATIVE_INFORMATION)
            .await
            .assert_status_not_in_range(200..);
    }

    #[tokio::test]
    async fn it_should_be_true_when_outside_from_range() {
        response_with_status(StatusCode::NON_AUTHORITATIVE_INFORMATION)
            .await
            .assert_status_not_in_range(500..);
    }

    #[tokio::test]
    #[should_panic]
    async fn it_should_be_false_for_full_range() {
        // `..` carries no element type, so it must be spelled out for inference.
        response_with_status(StatusCode::NON_AUTHORITATIVE_INFORMATION)
            .await
            .assert_status_not_in_range::<RangeFull, StatusCode>(..);
    }
}
#[cfg(test)]
mod test_into_bytes {
    //! Tests for `TestResponse::into_bytes`.
    use crate::TestServer;
    use axum::routing::get;
    use axum::Json;
    use axum::Router;
    use serde_json::json;
    use serde_json::Value;

    async fn route_get_json() -> Json<Value> {
        Json(json!({
            "message": "it works?"
        }))
    }

    /// `into_bytes` must hand back the raw, undecoded body exactly as it was sent.
    /// (Previously misnamed `it_should_deserialize_into_json`; nothing is deserialized here.)
    #[tokio::test]
    async fn it_should_return_raw_body_bytes() {
        let app = Router::new().route("/json", get(route_get_json));
        let server = TestServer::new(app).unwrap();
        let bytes = server.get("/json").await.into_bytes();
        let text = String::from_utf8_lossy(&bytes);
        assert_eq!(text, r#"{"message":"it works?"}"#);
    }
}
#[cfg(test)]
mod test_content_type {
    //! Tests for `TestResponse::content_type`.
    use crate::TestServer;
    use axum::routing::get;
    use axum::Json;
    use axum::Router;
    use serde::Deserialize;
    use serde::Serialize;

    #[derive(Serialize, Deserialize, PartialEq, Debug)]
    struct ExampleResponse {
        name: String,
        age: u32,
    }

    /// The payload every route in this module responds with.
    fn example_response() -> ExampleResponse {
        ExampleResponse {
            name: String::from("Joe"),
            age: 20,
        }
    }

    #[tokio::test]
    async fn it_should_retrieve_json_content_type_for_json() {
        let app = Router::new().route("/json", get(|| async { Json(example_response()) }));
        let server = TestServer::new(app).unwrap();
        let response = server.get("/json").await;
        assert_eq!(response.content_type(), "application/json");
    }

    #[cfg(feature = "yaml")]
    #[tokio::test]
    async fn it_should_retrieve_yaml_content_type_for_yaml() {
        use axum_yaml::Yaml;
        let app = Router::new().route("/yaml", get(|| async { Yaml(example_response()) }));
        let server = TestServer::new(app).unwrap();
        let response = server.get("/yaml").await;
        assert_eq!(response.content_type(), "application/yaml");
    }
}
#[cfg(test)]
mod test_json {
    //! Tests for `TestResponse::json`.
    use crate::TestServer;
    use axum::routing::get;
    use axum::Json;
    use axum::Router;
    use serde::Deserialize;
    use serde::Serialize;

    #[derive(Serialize, Deserialize, PartialEq, Debug)]
    struct ExampleResponse {
        name: String,
        age: u32,
    }

    /// The value served by the route and expected back after deserializing.
    fn example_response() -> ExampleResponse {
        ExampleResponse {
            name: String::from("Joe"),
            age: 20,
        }
    }

    async fn route_get_json() -> Json<ExampleResponse> {
        Json(example_response())
    }

    #[tokio::test]
    async fn it_should_deserialize_into_json() {
        let app = Router::new().route("/json", get(route_get_json));
        let server = TestServer::new(app).unwrap();
        let decoded: ExampleResponse = server.get("/json").await.json();
        assert_eq!(decoded, example_response());
    }
}
#[cfg(feature = "yaml")]
#[cfg(test)]
mod test_yaml {
    //! Tests for `TestResponse::yaml`.
    use crate::TestServer;
    use axum::routing::get;
    use axum::Router;
    use axum_yaml::Yaml;
    use serde::Deserialize;
    use serde::Serialize;

    #[derive(Serialize, Deserialize, PartialEq, Debug)]
    struct ExampleResponse {
        name: String,
        age: u32,
    }

    /// The value served by the route and expected back after deserializing.
    fn example_response() -> ExampleResponse {
        ExampleResponse {
            name: String::from("Joe"),
            age: 20,
        }
    }

    async fn route_get_yaml() -> Yaml<ExampleResponse> {
        Yaml(example_response())
    }

    #[tokio::test]
    async fn it_should_deserialize_into_yaml() {
        let app = Router::new().route("/yaml", get(route_get_yaml));
        let server = TestServer::new(app).unwrap();
        let decoded: ExampleResponse = server.get("/yaml").await.yaml();
        assert_eq!(decoded, example_response());
    }
}
#[cfg(feature = "msgpack")]
#[cfg(test)]
mod test_msgpack {
    //! Tests for `TestResponse::msgpack`.
    use crate::TestServer;
    use axum::routing::get;
    use axum::Router;
    use axum_msgpack::MsgPack;
    use serde::Deserialize;
    use serde::Serialize;

    #[derive(Serialize, Deserialize, PartialEq, Debug)]
    struct ExampleResponse {
        name: String,
        age: u32,
    }

    /// The value served by the route and expected back after deserializing.
    fn example_response() -> ExampleResponse {
        ExampleResponse {
            name: String::from("Joe"),
            age: 20,
        }
    }

    async fn route_get_msgpack() -> MsgPack<ExampleResponse> {
        MsgPack(example_response())
    }

    #[tokio::test]
    async fn it_should_deserialize_into_msgpack() {
        let app = Router::new().route("/msgpack", get(route_get_msgpack));
        let server = TestServer::new(app).unwrap();
        let decoded: ExampleResponse = server.get("/msgpack").await.msgpack();
        assert_eq!(decoded, example_response());
    }
}
#[cfg(test)]
mod test_form {
    // Checks that `TestResponse::form` decodes a url-encoded form body into a typed value.
    use crate::TestServer;
    use axum::routing::get;
    use axum::Form;
    use axum::Router;
    use serde::Deserialize;
    use serde::Serialize;

    #[derive(Serialize, Deserialize, PartialEq, Debug)]
    struct ExampleResponse {
        name: String,
        age: u32,
    }

    /// The record served by `/form` and expected back after decoding.
    fn expected_response() -> ExampleResponse {
        ExampleResponse {
            name: String::from("Joe"),
            age: 20,
        }
    }

    async fn route_get_form() -> Form<ExampleResponse> {
        Form(expected_response())
    }

    #[tokio::test]
    async fn it_should_deserialize_into_form() {
        let router = Router::new().route(&"/form", get(route_get_form));
        let server = TestServer::new(router).unwrap();

        let decoded = server.get(&"/form").await.form::<ExampleResponse>();

        assert_eq!(decoded, expected_response());
    }
}
#[cfg(test)]
mod test_from {
    // Checks that a `TestResponse` converts into its raw body as `Bytes`.
    use crate::TestServer;
    use axum::routing::get;
    use axum::Router;
    use bytes::Bytes;

    /// Body text returned by the `/text` route.
    const EXAMPLE_TEXT: &str = "This is some example text";

    #[tokio::test]
    async fn it_should_turn_into_response_bytes() {
        let router = Router::new().route(&"/text", get(|| async { EXAMPLE_TEXT }));
        let server = TestServer::new(router).unwrap();

        // Consume the whole response, keeping only the body bytes.
        let body: Bytes = server.get(&"/text").await.into();

        assert_eq!(String::from_utf8_lossy(&body), EXAMPLE_TEXT);
    }
}
#[cfg(test)]
mod test_assert_text {
    // `assert_text` requires the body to equal the expected text exactly.
    use crate::TestServer;
    use axum::routing::get;
    use axum::Router;

    /// The exact body returned by the `/text` route.
    const FULL_TEXT: &str = "This is some example text";

    /// Builds a server whose `/text` route always responds with `FULL_TEXT`.
    fn new_test_server() -> TestServer {
        let router = Router::new().route(&"/text", get(|| async { FULL_TEXT }));
        TestServer::new(router).unwrap()
    }

    #[tokio::test]
    async fn it_should_match_whole_text() {
        let server = new_test_server();
        let response = server.get(&"/text").await;
        response.assert_text(FULL_TEXT);
    }

    #[tokio::test]
    #[should_panic]
    async fn it_should_not_match_partial_text() {
        let server = new_test_server();
        let response = server.get(&"/text").await;
        // A substring is not an exact match, so this must panic.
        response.assert_text("some example");
    }

    #[tokio::test]
    #[should_panic]
    async fn it_should_not_match_different_text() {
        let server = new_test_server();
        let response = server.get(&"/text").await;
        response.assert_text("🦊");
    }
}
#[cfg(test)]
mod test_assert_text_contains {
    // `assert_text_contains` passes when the expected text appears anywhere in the body.
    use crate::TestServer;
    use axum::routing::get;
    use axum::Router;

    /// The exact body returned by the `/text` route.
    const FULL_TEXT: &str = "This is some example text";

    /// Builds a server whose `/text` route always responds with `FULL_TEXT`.
    fn new_test_server() -> TestServer {
        let router = Router::new().route(&"/text", get(|| async { FULL_TEXT }));
        TestServer::new(router).unwrap()
    }

    #[tokio::test]
    async fn it_should_match_whole_text() {
        let server = new_test_server();
        let response = server.get(&"/text").await;
        response.assert_text_contains(FULL_TEXT);
    }

    #[tokio::test]
    async fn it_should_match_partial_text() {
        let server = new_test_server();
        let response = server.get(&"/text").await;
        response.assert_text_contains("some example");
    }

    #[tokio::test]
    #[should_panic]
    async fn it_should_not_match_different_text() {
        let server = new_test_server();
        let response = server.get(&"/text").await;
        response.assert_text_contains("🦊");
    }
}
#[cfg(test)]
mod test_assert_text_from_file {
    // `assert_text_from_file` compares the body against the contents of a file on disk.
    use crate::TestServer;
    use axum::routing::get;
    use axum::routing::Router;

    /// Path (relative to the crate root when tests run) of the expected text.
    const EXPECTED_FILE: &str = "files/example.txt";

    /// Builds a server whose `/text` route responds with `body`.
    fn server_responding_with(body: &'static str) -> TestServer {
        let router = Router::new().route(&"/text", get(move || async move { body }));
        TestServer::new(router).unwrap()
    }

    #[tokio::test]
    async fn it_should_match_from_file() {
        let server = server_responding_with("hello!");
        let response = server.get(&"/text").await;
        response.assert_text_from_file(EXPECTED_FILE);
    }

    #[tokio::test]
    #[should_panic]
    async fn it_should_panic_when_not_match_the_file() {
        let server = server_responding_with("🦊");
        let response = server.get(&"/text").await;
        response.assert_text_from_file(EXPECTED_FILE);
    }
}
#[cfg(test)]
mod test_assert_json {
    // `assert_json` requires a JSON body equal to the expected value; non-JSON bodies fail.
    use crate::TestServer;
    use axum::routing::get;
    use axum::Form;
    use axum::Json;
    use axum::Router;
    use serde::Deserialize;
    use serde::Serialize;

    #[derive(Serialize, Deserialize, PartialEq, Debug)]
    struct ExampleResponse {
        name: String,
        age: u32,
    }

    /// The record every route in this module serves.
    fn joe() -> ExampleResponse {
        ExampleResponse {
            name: String::from("Joe"),
            age: 20,
        }
    }

    async fn route_get_form() -> Form<ExampleResponse> {
        Form(joe())
    }

    async fn route_get_json() -> Json<ExampleResponse> {
        Json(joe())
    }

    /// Server exposing `joe()` as JSON on `/json`.
    fn json_server() -> TestServer {
        TestServer::new(Router::new().route(&"/json", get(route_get_json))).unwrap()
    }

    #[tokio::test]
    async fn it_should_match_json_returned() {
        let server = json_server();
        let response = server.get(&"/json").await;
        response.assert_json(&joe());
    }

    #[tokio::test]
    #[should_panic]
    async fn it_should_panic_if_response_is_different() {
        let server = json_server();
        let response = server.get(&"/json").await;
        let julia = ExampleResponse {
            name: String::from("Julia"),
            age: 25,
        };
        response.assert_json(&julia);
    }

    #[tokio::test]
    #[should_panic]
    async fn it_should_panic_if_response_is_form() {
        // Same data, but form-encoded: the body is not JSON, so this must panic.
        let server = TestServer::new(Router::new().route(&"/form", get(route_get_form))).unwrap();
        let response = server.get(&"/form").await;
        response.assert_json(&joe());
    }
}
#[cfg(test)]
mod test_assert_json_contains {
    // Tests for `assert_json_contains`: the body must be JSON and must include the
    // expected fields, while extra fields (here `time`) are ignored.
    use crate::TestServer;
    use axum::routing::get;
    use axum::Form;
    use axum::Json;
    use axum::Router;
    use serde::Deserialize;
    use serde::Serialize;
    use serde_json::json;
    use std::time::SystemTime;
    use std::time::UNIX_EPOCH;

    #[derive(Serialize, Deserialize, PartialEq, Debug)]
    struct ExampleResponse {
        time: u64,
        name: String,
        age: u32,
    }

    /// Wall-clock seconds since the Unix epoch.
    ///
    /// Used as a `time` value the tests cannot predict, which demonstrates that
    /// `assert_json_contains` ignores fields the caller did not ask about.
    /// `Instant::now().elapsed()` would not work here. It measures from an instant
    /// created on the same line, so it is always ~0.
    fn current_time() -> u64 {
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .expect("system clock should be after the Unix epoch")
            .as_secs()
    }

    async fn route_get_form() -> Form<ExampleResponse> {
        Form(ExampleResponse {
            time: current_time(),
            name: "Joe".to_string(),
            age: 20,
        })
    }

    async fn route_get_json() -> Json<ExampleResponse> {
        Json(ExampleResponse {
            time: current_time(),
            name: "Joe".to_string(),
            age: 20,
        })
    }

    #[tokio::test]
    async fn it_should_match_subset_of_json_returned() {
        let app = Router::new().route(&"/json", get(route_get_json));
        let server = TestServer::new(app).unwrap();
        // `time` is deliberately omitted; only the listed fields must match.
        server.get(&"/json").await.assert_json_contains(&json!({
            "name": "Joe",
            "age": 20,
        }));
    }

    #[tokio::test]
    #[should_panic]
    async fn it_should_panic_if_response_is_different() {
        let app = Router::new().route(&"/json", get(route_get_json));
        let server = TestServer::new(app).unwrap();
        server
            .get(&"/json")
            .await
            .assert_json_contains(&ExampleResponse {
                time: 1234,
                name: "Julia".to_string(),
                age: 25,
            });
    }

    #[tokio::test]
    #[should_panic]
    async fn it_should_panic_if_response_is_form() {
        // The body is form-encoded rather than JSON, so this must panic.
        let app = Router::new().route(&"/form", get(route_get_form));
        let server = TestServer::new(app).unwrap();
        server.get(&"/form").await.assert_json_contains(&json!({
            "name": "Joe",
            "age": 20,
        }));
    }
}
#[cfg(test)]
mod test_assert_json_from_file {
    // `assert_json_from_file` compares a JSON body against JSON loaded from a file.
    use crate::TestServer;
    use axum::routing::get;
    use axum::routing::Router;
    use axum::Form;
    use axum::Json;
    use serde::Deserialize;
    use serde::Serialize;
    use serde_json::json;

    /// Path of the expected JSON document.
    const EXPECTED_FILE: &str = "files/example.json";

    /// Builds a server whose `/json` route returns `{"name": name, "age": age}` as JSON.
    fn json_server(name: &'static str, age: u32) -> TestServer {
        let router = Router::new().route(
            &"/json",
            get(move || async move { Json(json!({ "name": name, "age": age })) }),
        );
        TestServer::new(router).unwrap()
    }

    #[tokio::test]
    async fn it_should_match_json_from_file() {
        let server = json_server("Joe", 20);
        let response = server.get(&"/json").await;
        response.assert_json_from_file(EXPECTED_FILE);
    }

    #[tokio::test]
    #[should_panic]
    async fn it_should_panic_when_not_match_the_file() {
        let server = json_server("Julia", 25);
        let response = server.get(&"/json").await;
        response.assert_json_from_file(EXPECTED_FILE);
    }

    #[tokio::test]
    #[should_panic]
    async fn it_should_panic_when_content_type_does_not_match() {
        #[derive(Serialize, Deserialize, PartialEq, Debug)]
        struct ExampleResponse {
            name: String,
            age: u32,
        }

        // The data matches the file, but it is served as a form, not JSON.
        let router = Router::new().route(
            &"/form",
            get(|| async {
                Form(ExampleResponse {
                    name: String::from("Joe"),
                    age: 20,
                })
            }),
        );
        let server = TestServer::new(router).unwrap();
        let response = server.get(&"/form").await;
        response.assert_json_from_file(EXPECTED_FILE);
    }
}
#[cfg(feature = "yaml")]
#[cfg(test)]
mod test_assert_yaml {
    // `assert_yaml` requires a YAML body equal to the expected value; non-YAML bodies fail.
    use crate::TestServer;
    use axum::routing::get;
    use axum::Form;
    use axum::Router;
    use axum_yaml::Yaml;
    use serde::Deserialize;
    use serde::Serialize;

    #[derive(Serialize, Deserialize, PartialEq, Debug)]
    struct ExampleResponse {
        name: String,
        age: u32,
    }

    /// The record every route in this module serves.
    fn joe() -> ExampleResponse {
        ExampleResponse {
            name: String::from("Joe"),
            age: 20,
        }
    }

    async fn route_get_form() -> Form<ExampleResponse> {
        Form(joe())
    }

    async fn route_get_yaml() -> Yaml<ExampleResponse> {
        Yaml(joe())
    }

    /// Server exposing `joe()` as YAML on `/yaml`.
    fn yaml_server() -> TestServer {
        TestServer::new(Router::new().route(&"/yaml", get(route_get_yaml))).unwrap()
    }

    #[tokio::test]
    async fn it_should_match_yaml_returned() {
        let server = yaml_server();
        let response = server.get(&"/yaml").await;
        response.assert_yaml(&joe());
    }

    #[tokio::test]
    #[should_panic]
    async fn it_should_panic_if_response_is_different() {
        let server = yaml_server();
        let response = server.get(&"/yaml").await;
        let julia = ExampleResponse {
            name: String::from("Julia"),
            age: 25,
        };
        response.assert_yaml(&julia);
    }

    #[tokio::test]
    #[should_panic]
    async fn it_should_panic_if_response_is_form() {
        // Same data, but form-encoded: the body is not YAML, so this must panic.
        let server = TestServer::new(Router::new().route(&"/form", get(route_get_form))).unwrap();
        let response = server.get(&"/form").await;
        response.assert_yaml(&joe());
    }
}
#[cfg(feature = "yaml")]
#[cfg(test)]
mod test_assert_yaml_from_file {
    // `assert_yaml_from_file` compares a YAML body against YAML loaded from a file.
    use crate::TestServer;
    use axum::routing::get;
    use axum::routing::Router;
    use axum::Form;
    use axum_yaml::Yaml;
    use serde::Deserialize;
    use serde::Serialize;
    use serde_json::json;

    /// Path of the expected YAML document.
    const EXPECTED_FILE: &str = "files/example.yaml";

    /// Builds a server whose `/yaml` route returns `{name, age}` encoded as YAML.
    fn yaml_server(name: &'static str, age: u32) -> TestServer {
        let router = Router::new().route(
            &"/yaml",
            get(move || async move { Yaml(json!({ "name": name, "age": age })) }),
        );
        TestServer::new(router).unwrap()
    }

    #[tokio::test]
    async fn it_should_match_yaml_from_file() {
        let server = yaml_server("Joe", 20);
        let response = server.get(&"/yaml").await;
        response.assert_yaml_from_file(EXPECTED_FILE);
    }

    #[tokio::test]
    #[should_panic]
    async fn it_should_panic_when_not_match_the_file() {
        let server = yaml_server("Julia", 25);
        let response = server.get(&"/yaml").await;
        response.assert_yaml_from_file(EXPECTED_FILE);
    }

    #[tokio::test]
    #[should_panic]
    async fn it_should_panic_when_content_type_does_not_match() {
        #[derive(Serialize, Deserialize, PartialEq, Debug)]
        struct ExampleResponse {
            name: String,
            age: u32,
        }

        // The data matches the file, but it is served as a form, not YAML.
        let router = Router::new().route(
            &"/form",
            get(|| async {
                Form(ExampleResponse {
                    name: String::from("Joe"),
                    age: 20,
                })
            }),
        );
        let server = TestServer::new(router).unwrap();
        let response = server.get(&"/form").await;
        response.assert_yaml_from_file(EXPECTED_FILE);
    }
}
#[cfg(test)]
mod test_assert_form {
    // `assert_form` requires a form-encoded body equal to the expected value; JSON bodies fail.
    use crate::TestServer;
    use axum::routing::get;
    use axum::Form;
    use axum::Json;
    use axum::Router;
    use serde::Deserialize;
    use serde::Serialize;

    #[derive(Serialize, Deserialize, PartialEq, Debug)]
    struct ExampleResponse {
        name: String,
        age: u32,
    }

    /// The record every route in this module serves.
    fn joe() -> ExampleResponse {
        ExampleResponse {
            name: String::from("Joe"),
            age: 20,
        }
    }

    async fn route_get_form() -> Form<ExampleResponse> {
        Form(joe())
    }

    async fn route_get_json() -> Json<ExampleResponse> {
        Json(joe())
    }

    /// Server exposing `joe()` as a url-encoded form on `/form`.
    fn form_server() -> TestServer {
        TestServer::new(Router::new().route(&"/form", get(route_get_form))).unwrap()
    }

    #[tokio::test]
    async fn it_should_match_form_returned() {
        let server = form_server();
        let response = server.get(&"/form").await;
        response.assert_form(&joe());
    }

    #[tokio::test]
    #[should_panic]
    async fn it_should_panic_if_response_is_different() {
        let server = form_server();
        let response = server.get(&"/form").await;
        let julia = ExampleResponse {
            name: String::from("Julia"),
            age: 25,
        };
        response.assert_form(&julia);
    }

    #[tokio::test]
    #[should_panic]
    async fn it_should_panic_if_response_is_json() {
        // Same data, but JSON-encoded: the body is not a form, so this must panic.
        let server = TestServer::new(Router::new().route(&"/json", get(route_get_json))).unwrap();
        let response = server.get(&"/json").await;
        response.assert_form(&joe());
    }
}
#[cfg(test)]
mod test_text {
    // Checks that `TestResponse::text` returns the body as a `String`.
    use crate::TestServer;
    use axum::routing::get;
    use axum::Router;

    #[tokio::test]
    async fn it_should_deserialize_into_text() {
        // The handler produces an owned `String` body.
        let router = Router::new().route(&"/text", get(|| async { "hello!".to_string() }));
        let server = TestServer::new(router).unwrap();

        let body = server.get(&"/text").await.text();

        assert_eq!(body, "hello!");
    }
}
#[cfg(feature = "ws")]
#[cfg(test)]
mod test_into_websocket {
    // Tests for upgrading a `get_websocket` response into a live WebSocket.
    use crate::TestServer;
    use axum::extract::ws::WebSocket;
    use axum::extract::WebSocketUpgrade;
    use axum::response::Response;
    use axum::routing::get;
    use axum::Router;

    /// Builds a router with a single `/ws` route. The route accepts the WebSocket
    /// upgrade and then reads messages until the client disconnects.
    fn new_test_router() -> Router {
        async fn route_get_websocket(ws: WebSocketUpgrade) -> Response {
            ws.on_upgrade(handle_ping_pong)
        }

        // Reads and discards every incoming message; stops when `recv` yields `None`.
        async fn handle_ping_pong(mut socket: WebSocket) {
            while socket.recv().await.is_some() {}
        }

        Router::new().route(&"/ws", get(route_get_websocket))
    }

    #[tokio::test]
    async fn it_should_upgrade_on_http_transport() {
        let router = new_test_router();
        let server = TestServer::builder()
            .http_transport()
            .build(router)
            .unwrap();

        // A failed upgrade panics (see the mock-transport test below), so completing
        // this call without panicking is the success condition.
        let _ = server.get_websocket(&"/ws").await.into_websocket().await;
    }

    #[tokio::test]
    #[should_panic]
    async fn it_should_fail_to_upgrade_on_mock_transport() {
        let router = new_test_router();
        let server = TestServer::builder()
            .mock_transport()
            .build(router)
            .unwrap();

        let _ = server.get_websocket(&"/ws").await.into_websocket().await;
    }
}