use anyhow::anyhow;
use anyhow::Context;
use anyhow::Error as AnyhowError;
use anyhow::Result;
use auto_future::AutoFuture;
use axum::body::Body;
use bytes::Bytes;
use cookie::time::OffsetDateTime;
use cookie::Cookie;
use cookie::CookieJar;
use http::header;
use http::header::SET_COOKIE;
use http::HeaderName;
use http::HeaderValue;
use http::Method;
use http::Request;
use http_body_util::BodyExt;
use serde::Serialize;
use std::fmt::Debug;
use std::fmt::Display;
use std::fs::read;
use std::fs::read_to_string;
use std::fs::File;
use std::future::IntoFuture;
use std::io::BufReader;
use std::path::Path;
use std::sync::Arc;
use std::sync::Mutex;
use url::Url;
use crate::internals::ExpectedState;
use crate::internals::QueryParamsStore;
use crate::internals::RequestPathFormatter;
use crate::multipart::MultipartForm;
use crate::transport_layer::TransportLayer;
use crate::ServerSharedState;
use crate::TestResponse;
mod test_request_config;
pub(crate) use self::test_request_config::*;
/// A single request being built against a test server.
///
/// Instances are created crate-internally via `TestRequest::new`, configured with the
/// builder-style methods (body, headers, cookies, query params, ...), and then either
/// sent by `.await`ing them (see the `IntoFuture` impl) or converted into a plain
/// `http::Request<Body>` through `TryFrom`.
#[derive(Debug)]
#[must_use = "futures do nothing unless polled"]
pub struct TestRequest {
// Method, URL, content type, headers, cookies, query params and the cookie-saving flag.
config: TestRequestConfig,
// Shared server state; `Set-Cookie` headers from the response are stored here when
// cookie saving is enabled.
server_state: Arc<Mutex<ServerSharedState>>,
// Transport used to actually send the built request.
transport: Arc<Box<dyn TransportLayer>>,
// Request body; `None` is sent as an empty body.
body: Option<Body>,
// Status expectation checked against the response after sending.
expected_state: ExpectedState,
}
impl TestRequest {
    /// Creates a new request from the shared server state, transport and config.
    ///
    /// The request's status expectation starts out as `config.expected_state`; it can
    /// be overridden per request with [`TestRequest::expect_success`] or
    /// [`TestRequest::expect_failure`].
    pub(crate) fn new(
        server_state: Arc<Mutex<ServerSharedState>>,
        transport: Arc<Box<dyn TransportLayer>>,
        config: TestRequestConfig,
    ) -> Self {
        let expected_state = config.expected_state;
        Self {
            config,
            server_state,
            transport,
            body: None,
            expected_state,
        }
    }

    /// Sets the body to `body` serialized as Json, and the content type to
    /// `application/json`.
    ///
    /// # Panics
    ///
    /// Panics if `body` cannot be serialized into Json.
    pub fn json<J>(self, body: &J) -> Self
    where
        J: ?Sized + Serialize,
    {
        let body_bytes =
            serde_json::to_vec(body).expect("It should serialize the content into Json");
        self.bytes(body_bytes.into())
            .content_type(mime::APPLICATION_JSON.essence_str())
    }

    /// Reads the file at `path`, parses it as Json, and sends it via [`TestRequest::json`].
    ///
    /// The file is parsed into a `serde_json::Value` and re-serialized, so the exact
    /// formatting of the file (whitespace etc.) is not preserved in the body.
    ///
    /// # Panics
    ///
    /// Panics if the file cannot be opened, or its contents are not valid Json.
    pub fn json_from_file<P>(self, path: P) -> Self
    where
        P: AsRef<Path>,
    {
        let path_ref = path.as_ref();
        let file = File::open(path_ref)
            .with_context(|| format!("Failed to read from file '{}'", path_ref.display()))
            .unwrap();

        let reader = BufReader::new(file);
        let payload = serde_json::from_reader::<_, serde_json::Value>(reader)
            .with_context(|| {
                format!(
                    "Failed to deserialize file '{}' as Json",
                    path_ref.display()
                )
            })
            .unwrap();

        self.json(&payload)
    }

    /// Sets the body to `body` serialized as Yaml, and the content type to
    /// `application/yaml`.
    ///
    /// # Panics
    ///
    /// Panics if `body` cannot be serialized into Yaml.
    #[cfg(feature = "yaml")]
    pub fn yaml<Y>(self, body: &Y) -> Self
    where
        Y: ?Sized + Serialize,
    {
        let body = serde_yaml::to_string(body).expect("It should serialize the content into Yaml");
        self.bytes(body.into_bytes().into())
            .content_type("application/yaml")
    }

    /// Reads the file at `path`, parses it as Yaml, and sends it via [`TestRequest::yaml`].
    ///
    /// The file is parsed into a `serde_yaml::Value` and re-serialized, so the exact
    /// formatting of the file is not preserved in the body.
    ///
    /// # Panics
    ///
    /// Panics if the file cannot be opened, or its contents are not valid Yaml.
    #[cfg(feature = "yaml")]
    pub fn yaml_from_file<P>(self, path: P) -> Self
    where
        P: AsRef<Path>,
    {
        let path_ref = path.as_ref();
        let file = File::open(path_ref)
            .with_context(|| format!("Failed to read from file '{}'", path_ref.display()))
            .unwrap();

        let reader = BufReader::new(file);
        let payload = serde_yaml::from_reader::<_, serde_yaml::Value>(reader)
            .with_context(|| {
                format!(
                    "Failed to deserialize file '{}' as Yaml",
                    path_ref.display()
                )
            })
            .unwrap();

        self.yaml(&payload)
    }

    /// Sets the body to `body` serialized as MsgPack, and the content type to
    /// `application/msgpack`.
    ///
    /// # Panics
    ///
    /// Panics if `body` cannot be serialized into MsgPack.
    #[cfg(feature = "msgpack")]
    pub fn msgpack<M>(self, body: &M) -> Self
    where
        M: ?Sized + Serialize,
    {
        let body_bytes =
            ::rmp_serde::to_vec(body).expect("It should serialize the content into MsgPack");
        self.bytes(body_bytes.into())
            .content_type("application/msgpack")
    }

    /// Sets the body to `body` serialized as a url-encoded form, and the content type
    /// to `application/x-www-form-urlencoded`.
    ///
    /// # Panics
    ///
    /// Panics if `body` cannot be serialized as a url-encoded form.
    pub fn form<F>(self, body: &F) -> Self
    where
        F: ?Sized + Serialize,
    {
        let body_text =
            serde_urlencoded::to_string(body).expect("It should serialize the content into a Form");
        self.bytes(body_text.into())
            .content_type(mime::APPLICATION_WWW_FORM_URLENCODED.essence_str())
    }

    /// Sets the body to the given multipart form, with the content type taken from
    /// the form itself.
    pub fn multipart(mut self, multipart: MultipartForm) -> Self {
        self.config.content_type = Some(multipart.content_type());
        self.body = Some(multipart.into());
        self
    }

    /// Sets the body to the `Display` output of `raw_text`, and the content type to
    /// `text/plain`.
    pub fn text<T>(self, raw_text: T) -> Self
    where
        T: Display,
    {
        let body_text = raw_text.to_string();
        self.bytes(body_text.into())
            .content_type(mime::TEXT_PLAIN.essence_str())
    }

    /// Reads the file at `path` as a UTF-8 string and sends it via [`TestRequest::text`].
    ///
    /// # Panics
    ///
    /// Panics if the file cannot be read as a string.
    pub fn text_from_file<P>(self, path: P) -> Self
    where
        P: AsRef<Path>,
    {
        let path_ref = path.as_ref();
        let payload = read_to_string(path_ref)
            .with_context(|| format!("Failed to read from file '{}'", path_ref.display()))
            .unwrap();

        self.text(payload)
    }

    /// Sets the raw body of the request.
    ///
    /// The content type is left untouched.
    pub fn bytes(mut self, body_bytes: Bytes) -> Self {
        let body: Body = body_bytes.into();
        self.body = Some(body);
        self
    }

    /// Reads the file at `path` and sends its raw bytes via [`TestRequest::bytes`].
    ///
    /// # Panics
    ///
    /// Panics if the file cannot be read.
    pub fn bytes_from_file<P>(self, path: P) -> Self
    where
        P: AsRef<Path>,
    {
        let path_ref = path.as_ref();
        let payload = read(path_ref)
            .with_context(|| format!("Failed to read from file '{}'", path_ref.display()))
            .unwrap();

        self.bytes(payload.into())
    }

    /// Sets the `Content-Type` header for this request, replacing any value set before.
    pub fn content_type(mut self, content_type: &str) -> Self {
        self.config.content_type = Some(content_type.to_string());
        self
    }

    /// Adds a cookie to be sent with this request.
    pub fn add_cookie(mut self, cookie: Cookie<'_>) -> Self {
        self.config.cookies.add(cookie.into_owned());
        self
    }

    /// Adds every cookie in `cookies` to be sent with this request.
    pub fn add_cookies(mut self, cookies: CookieJar) -> Self {
        for cookie in cookies.iter() {
            self.config.cookies.add(cookie.clone());
        }
        self
    }

    /// Removes every cookie currently stored on this request.
    pub fn clear_cookies(mut self) -> Self {
        self.config.cookies = CookieJar::new();
        self
    }

    /// After the response arrives, stores its `Set-Cookie` headers in the shared
    /// server state.
    pub fn save_cookies(mut self) -> Self {
        self.config.is_saving_cookies = true;
        self
    }

    /// Disables storing the response's `Set-Cookie` headers in the shared server state.
    pub fn do_not_save_cookies(mut self) -> Self {
        self.config.is_saving_cookies = false;
        self
    }

    /// Adds a single `key=value` query parameter.
    ///
    /// # Panics
    ///
    /// Panics if `value` cannot be serialized as a query parameter.
    pub fn add_query_param<V>(self, key: &str, value: V) -> Self
    where
        V: Serialize,
    {
        self.add_query_params(&[(key, value)])
    }

    /// Adds query parameters serialized from `query_params`.
    ///
    /// # Panics
    ///
    /// Panics if `query_params` cannot be serialized as query parameters.
    pub fn add_query_params<V>(mut self, query_params: V) -> Self
    where
        V: Serialize,
    {
        self.config
            .query_params
            .add(query_params)
            .with_context(|| {
                format!(
                    "It should serialize query parameters, for request {}",
                    self.debug_request_format()
                )
            })
            .unwrap();

        self
    }

    /// Adds a raw, pre-formatted query parameter string as-is.
    pub fn add_raw_query_param(mut self, query_param: &str) -> Self {
        self.config.query_params.add_raw(query_param.to_string());
        self
    }

    /// Removes all query parameters stored on this request.
    pub fn clear_query_params(mut self) -> Self {
        self.config.query_params.clear();
        self
    }

    /// Adds a header to the request. Repeated calls with the same name add
    /// additional headers rather than replacing earlier ones.
    ///
    /// # Panics
    ///
    /// Panics if `name` or `value` cannot be converted into a valid header name/value.
    pub fn add_header<N, V>(mut self, name: N, value: V) -> Self
    where
        N: TryInto<HeaderName>,
        N::Error: Debug,
        V: TryInto<HeaderValue>,
        V::Error: Debug,
    {
        let header_name: HeaderName = name
            .try_into()
            .expect("Failed to convert header name to HeaderName");
        let header_value: HeaderValue = value
            .try_into()
            .expect("Failed to convert header value to HeaderValue");

        self.config.headers.push((header_name, header_value));
        self
    }

    /// Adds an `Authorization` header with the given raw value.
    ///
    /// # Panics
    ///
    /// Panics if the value is not a valid header value.
    pub fn authorization<T>(self, authorization_header: T) -> Self
    where
        T: AsRef<str>,
    {
        let authorization_header_value = HeaderValue::from_str(authorization_header.as_ref())
            .expect("Cannot build Authorization HeaderValue from token");

        self.add_header(header::AUTHORIZATION, authorization_header_value)
    }

    /// Adds an `Authorization: Bearer <token>` header.
    ///
    /// # Panics
    ///
    /// Panics if the resulting value is not a valid header value.
    pub fn authorization_bearer<T>(self, authorization_bearer_token: T) -> Self
    where
        T: Display,
    {
        let authorization_bearer_header_str = format!("Bearer {authorization_bearer_token}");
        self.authorization(authorization_bearer_header_str)
    }

    /// Removes all headers added to this request (content type and cookies are kept
    /// separately and are not affected).
    pub fn clear_headers(mut self) -> Self {
        self.config.headers = vec![];
        self
    }

    /// Changes the scheme of the request URL (e.g. `http` or `https`).
    ///
    /// # Panics
    ///
    /// Panics if the URL rejects the new scheme.
    pub fn scheme(mut self, scheme: &str) -> Self {
        self.config
            .full_request_url
            .set_scheme(scheme)
            .map_err(|_| anyhow!("Scheme '{scheme}' cannot be set to request"))
            .unwrap();
        self
    }

    /// Asserts the response has a success status once the request is sent.
    pub fn expect_success(self) -> Self {
        self.expect_state(ExpectedState::Success)
    }

    /// Asserts the response has a failure status once the request is sent.
    pub fn expect_failure(self) -> Self {
        self.expect_state(ExpectedState::Failure)
    }

    fn expect_state(mut self, expected_state: ExpectedState) -> Self {
        self.expected_state = expected_state;
        self
    }

    /// Builds and sends the request, collects the full response body, optionally
    /// saves response cookies, and checks the expected status state.
    async fn send(self) -> Result<TestResponse> {
        let debug_request_format = self.debug_request_format().to_string();

        let method = self.config.method;
        let expected_state = self.expected_state;
        let save_cookies = self.config.is_saving_cookies;
        let body = self.body.unwrap_or_else(Body::empty);
        let url =
            Self::build_url_query_params(self.config.full_request_url, &self.config.query_params);

        let request = Self::build_request(
            method.clone(),
            &url,
            body,
            self.config.content_type,
            self.config.cookies,
            self.config.headers,
            &debug_request_format,
        )?;

        // `mut` is only needed when the `ws` feature is enabled, where the upgrade
        // extension is removed from the response below.
        #[allow(unused_mut)]
        let mut http_response = self.transport.send(request).await?;

        #[cfg(feature = "ws")]
        let websockets = {
            let maybe_on_upgrade = http_response
                .extensions_mut()
                .remove::<hyper::upgrade::OnUpgrade>();
            let transport_type = self.transport.transport_layer_type();

            crate::internals::TestResponseWebSocket {
                maybe_on_upgrade,
                transport_type,
            }
        };

        let (parts, response_body) = http_response.into_parts();
        let response_bytes = response_body.collect().await?.to_bytes();

        if save_cookies {
            let cookie_headers = parts.headers.get_all(SET_COOKIE).into_iter();
            ServerSharedState::add_cookies_by_header(&self.server_state, cookie_headers)?;
        }

        let test_response = TestResponse::new(
            method,
            url,
            parts,
            response_bytes,
            #[cfg(feature = "ws")]
            websockets,
        );

        match expected_state {
            ExpectedState::Success => test_response.assert_status_success(),
            ExpectedState::Failure => test_response.assert_status_failure(),
            ExpectedState::None => {}
        }

        Ok(test_response)
    }

    /// Replaces the URL's query string with the stored query params, if there are any.
    fn build_url_query_params(mut url: Url, query_params: &QueryParamsStore) -> Url {
        if query_params.has_content() {
            url.set_query(Some(&query_params.to_string()));
        }
        url
    }

    /// Assembles the `http::Request` from its parts.
    ///
    /// All non-expired cookies are sent in a single `Cookie` header, joined by `"; "`,
    /// as RFC 6265 section 5.4 forbids a client from sending more than one `Cookie`
    /// header. No `Cookie` header is added when there are no live cookies.
    ///
    /// # Errors
    ///
    /// Returns an error if the content type or cookies are not valid header values,
    /// or if the request cannot be built.
    fn build_request(
        method: Method,
        url: &Url,
        body: Body,
        content_type: Option<String>,
        cookies: CookieJar,
        headers: Vec<(HeaderName, HeaderValue)>,
        debug_request_format: &str,
    ) -> Result<Request<Body>> {
        let mut request_builder = Request::builder().uri(url.as_str()).method(method);

        if let Some(content_type) = content_type {
            let (header_key, header_value) =
                build_content_type_header(&content_type, debug_request_format)?;
            request_builder = request_builder.header(header_key, header_value);
        }

        let now = OffsetDateTime::now_utc();
        let cookie_pairs: Vec<String> = cookies
            .iter()
            .filter(|cookie| !Self::is_cookie_expired(cookie, now))
            // Only `name=value` is sent; attributes like Path or HttpOnly are dropped.
            .map(|cookie| cookie.stripped().to_string())
            .collect();
        if !cookie_pairs.is_empty() {
            let header_value = HeaderValue::from_str(&cookie_pairs.join("; ")).with_context(|| {
                format!("Failed to build Cookie header, for request {debug_request_format}")
            })?;
            request_builder = request_builder.header(header::COOKIE, header_value);
        }

        for (header_name, header_value) in headers {
            request_builder = request_builder.header(header_name, header_value);
        }

        let request = request_builder.body(body).with_context(|| {
            format!("Expect valid hyper Request to be built, for request {debug_request_format}")
        })?;

        Ok(request)
    }

    /// Returns true when `cookie` has an explicit expiry date-time at or before `now`.
    ///
    /// Cookies without an expiry date-time are never treated as expired here.
    fn is_cookie_expired(cookie: &Cookie<'_>, now: OffsetDateTime) -> bool {
        cookie
            .expires_datetime()
            .map(|expires| expires <= now)
            .unwrap_or(false)
    }

    /// Formatter describing this request (method, URL, query params) for error messages.
    fn debug_request_format(&self) -> RequestPathFormatter<'_> {
        RequestPathFormatter::new(
            &self.config.method,
            self.config.full_request_url.as_str(),
            Some(&self.config.query_params),
        )
    }
}
impl TryFrom<TestRequest> for Request<Body> {
type Error = AnyhowError;
fn try_from(test_request: TestRequest) -> Result<Request<Body>> {
let debug_request_format = test_request.debug_request_format().to_string();
let url = TestRequest::build_url_query_params(
test_request.config.full_request_url,
&test_request.config.query_params,
);
let body = test_request.body.unwrap_or(Body::empty());
TestRequest::build_request(
test_request.config.method,
&url,
body,
test_request.config.content_type,
test_request.config.cookies,
test_request.config.headers,
&debug_request_format,
)
}
}
/// Awaiting a `TestRequest` sends it and resolves to its `TestResponse`.
///
/// If sending fails, the error is wrapped with the context "Sending request failed"
/// and the future panics.
impl IntoFuture for TestRequest {
    type Output = TestResponse;
    type IntoFuture = AutoFuture<TestResponse>;

    fn into_future(self) -> Self::IntoFuture {
        let request = self;
        AutoFuture::new(async move {
            let outcome = request.send().await;
            outcome.context("Sending request failed").unwrap()
        })
    }
}
/// Builds the `Content-Type` header pair for `content_type`.
///
/// # Errors
///
/// Returns an error (mentioning `debug_request_format`) if `content_type` is not a
/// valid header value.
fn build_content_type_header(
    content_type: &str,
    debug_request_format: &str,
) -> Result<(HeaderName, HeaderValue)> {
    HeaderValue::from_str(content_type)
        .with_context(|| {
            format!(
                "Failed to store header content type '{content_type}', for request {debug_request_format}"
            )
        })
        .map(|value| (header::CONTENT_TYPE, value))
}
#[cfg(test)]
mod test_content_type {
    use crate::TestServer;
    use axum::routing::get;
    use axum::Router;
    use http::header::CONTENT_TYPE;
    use http::HeaderMap;

    /// Echoes back the request's Content-Type header, or an empty string when absent.
    async fn get_content_type(headers: HeaderMap) -> String {
        match headers.get(CONTENT_TYPE) {
            Some(value) => value.to_str().unwrap().to_string(),
            None => String::new(),
        }
    }

    /// Router exposing `GET /content_type`.
    fn content_type_app() -> Router {
        Router::new().route("/content_type", get(get_content_type))
    }

    #[tokio::test]
    async fn it_should_not_set_a_content_type_by_default() {
        let server = TestServer::new(content_type_app()).expect("Should create test server");

        let received = server.get("/content_type").await.text();

        assert_eq!(received, "");
    }

    #[tokio::test]
    async fn it_should_override_server_content_type_when_present() {
        let server = TestServer::builder()
            .default_content_type("text/plain")
            .build(content_type_app())
            .expect("Should create test server");

        let received = server
            .get("/content_type")
            .content_type("application/json")
            .await
            .text();

        assert_eq!(received, "application/json");
    }

    #[tokio::test]
    async fn it_should_set_content_type_when_present() {
        let server = TestServer::new(content_type_app()).expect("Should create test server");

        let received = server
            .get("/content_type")
            .content_type("application/custom")
            .await
            .text();

        assert_eq!(received, "application/custom");
    }
}
#[cfg(test)]
mod test_json {
    use crate::TestServer;
    use axum::routing::post;
    use axum::Json;
    use axum::Router;
    use http::header::CONTENT_TYPE;
    use http::HeaderMap;
    use serde::Deserialize;
    use serde::Serialize;
    use serde_json::json;

    /// Payload sent to the `/json` route.
    #[derive(Deserialize, Serialize)]
    struct TestJson {
        name: String,
        age: u32,
        pets: Option<String>,
    }

    /// Formats the received Json payload so the test can assert on it.
    async fn post_json(Json(payload): Json<TestJson>) -> String {
        let pets = payload.pets.unwrap_or_else(|| "pandas".to_string());
        format!("json: {}, {}, {}", payload.name, payload.age, pets)
    }

    /// Echoes back the request's Content-Type header, or an empty string when absent.
    async fn echo_content_type(headers: HeaderMap) -> String {
        headers
            .get(CONTENT_TYPE)
            .map(|value| value.to_str().unwrap().to_string())
            .unwrap_or_default()
    }

    #[tokio::test]
    async fn it_should_pass_json_up_to_be_read() {
        let app = Router::new().route("/json", post(post_json));
        let server = TestServer::new(app).expect("Should create test server");
        let payload = TestJson {
            name: "Joe".to_string(),
            age: 20,
            pets: Some("foxes".to_string()),
        };

        let received = server.post("/json").json(&payload).await.text();

        assert_eq!(received, "json: Joe, 20, foxes");
    }

    #[tokio::test]
    async fn it_should_pass_json_content_type_for_json() {
        let app = Router::new().route("/content_type", post(echo_content_type));
        let server = TestServer::new(app).expect("Should create test server");

        let received = server.post("/content_type").json(&json!({})).await.text();

        assert_eq!(received, "application/json");
    }
}
#[cfg(test)]
mod test_json_from_file {
    use crate::TestServer;
    use axum::routing::post;
    use axum::Json;
    use axum::Router;
    use http::header::CONTENT_TYPE;
    use http::HeaderMap;
    use serde::Deserialize;
    use serde::Serialize;

    /// Json fixture sent by these tests.
    const EXAMPLE_JSON_FILE: &str = "files/example.json";

    /// Expected shape of the Json fixture.
    #[derive(Deserialize, Serialize)]
    struct TestJson {
        name: String,
        age: u32,
    }

    /// Formats the received Json payload so the test can assert on it.
    async fn post_json(Json(payload): Json<TestJson>) -> String {
        format!("json: {}, {}", payload.name, payload.age)
    }

    /// Echoes back the request's Content-Type header, or an empty string when absent.
    async fn echo_content_type(headers: HeaderMap) -> String {
        headers
            .get(CONTENT_TYPE)
            .map(|value| value.to_str().unwrap().to_string())
            .unwrap_or_default()
    }

    #[tokio::test]
    async fn it_should_pass_json_up_to_be_read() {
        let app = Router::new().route("/json", post(post_json));
        let server = TestServer::new(app).expect("Should create test server");

        let received = server
            .post("/json")
            .json_from_file(EXAMPLE_JSON_FILE)
            .await
            .text();

        assert_eq!(received, "json: Joe, 20");
    }

    #[tokio::test]
    async fn it_should_pass_json_content_type_for_json() {
        let app = Router::new().route("/content_type", post(echo_content_type));
        let server = TestServer::new(app).expect("Should create test server");

        let received = server
            .post("/content_type")
            .json_from_file(EXAMPLE_JSON_FILE)
            .await
            .text();

        assert_eq!(received, "application/json");
    }
}
#[cfg(feature = "yaml")]
#[cfg(test)]
mod test_yaml {
    use crate::TestServer;
    use axum::routing::post;
    use axum::Router;
    use axum_yaml::Yaml;
    use http::header::CONTENT_TYPE;
    use http::HeaderMap;
    use serde::Deserialize;
    use serde::Serialize;
    use serde_json::json;

    /// Payload sent to the `/yaml` route.
    #[derive(Deserialize, Serialize)]
    struct TestYaml {
        name: String,
        age: u32,
        pets: Option<String>,
    }

    /// Formats the received Yaml payload so the test can assert on it.
    async fn post_yaml(Yaml(payload): Yaml<TestYaml>) -> String {
        let pets = payload.pets.unwrap_or_else(|| "pandas".to_string());
        format!("yaml: {}, {}, {}", payload.name, payload.age, pets)
    }

    /// Echoes back the request's Content-Type header, or an empty string when absent.
    async fn echo_content_type(headers: HeaderMap) -> String {
        headers
            .get(CONTENT_TYPE)
            .map(|value| value.to_str().unwrap().to_string())
            .unwrap_or_default()
    }

    #[tokio::test]
    async fn it_should_pass_yaml_up_to_be_read() {
        let app = Router::new().route("/yaml", post(post_yaml));
        let server = TestServer::new(app).expect("Should create test server");
        let payload = TestYaml {
            name: "Joe".to_string(),
            age: 20,
            pets: Some("foxes".to_string()),
        };

        let received = server.post("/yaml").yaml(&payload).await.text();

        assert_eq!(received, "yaml: Joe, 20, foxes");
    }

    #[tokio::test]
    async fn it_should_pass_yaml_content_type_for_yaml() {
        let app = Router::new().route("/content_type", post(echo_content_type));
        let server = TestServer::new(app).expect("Should create test server");

        let received = server.post("/content_type").yaml(&json!({})).await.text();

        assert_eq!(received, "application/yaml");
    }
}
#[cfg(feature = "yaml")]
#[cfg(test)]
mod test_yaml_from_file {
    use crate::TestServer;
    use axum::routing::post;
    use axum::Router;
    use axum_yaml::Yaml;
    use http::header::CONTENT_TYPE;
    use http::HeaderMap;
    use serde::Deserialize;
    use serde::Serialize;

    /// Yaml fixture sent by these tests.
    const EXAMPLE_YAML_FILE: &str = "files/example.yaml";

    /// Expected shape of the Yaml fixture.
    #[derive(Deserialize, Serialize)]
    struct TestYaml {
        name: String,
        age: u32,
    }

    /// Formats the received Yaml payload so the test can assert on it.
    async fn post_yaml(Yaml(payload): Yaml<TestYaml>) -> String {
        format!("yaml: {}, {}", payload.name, payload.age)
    }

    /// Echoes back the request's Content-Type header, or an empty string when absent.
    async fn echo_content_type(headers: HeaderMap) -> String {
        headers
            .get(CONTENT_TYPE)
            .map(|value| value.to_str().unwrap().to_string())
            .unwrap_or_default()
    }

    #[tokio::test]
    async fn it_should_pass_yaml_up_to_be_read() {
        let app = Router::new().route("/yaml", post(post_yaml));
        let server = TestServer::new(app).expect("Should create test server");

        let received = server
            .post("/yaml")
            .yaml_from_file(EXAMPLE_YAML_FILE)
            .await
            .text();

        assert_eq!(received, "yaml: Joe, 20");
    }

    #[tokio::test]
    async fn it_should_pass_yaml_content_type_for_yaml() {
        let app = Router::new().route("/content_type", post(echo_content_type));
        let server = TestServer::new(app).expect("Should create test server");

        let received = server
            .post("/content_type")
            .yaml_from_file(EXAMPLE_YAML_FILE)
            .await
            .text();

        assert_eq!(received, "application/yaml");
    }
}
#[cfg(feature = "msgpack")]
#[cfg(test)]
mod test_msgpack {
    use crate::TestServer;
    use axum::routing::post;
    use axum::Router;
    use axum_msgpack::MsgPack;
    use http::header::CONTENT_TYPE;
    use http::HeaderMap;
    use serde::Deserialize;
    use serde::Serialize;
    use serde_json::json;

    /// Payload sent to the `/msgpack` route.
    #[derive(Deserialize, Serialize)]
    struct TestMsgPack {
        name: String,
        age: u32,
        pets: Option<String>,
    }

    /// Formats the received MsgPack payload so the test can assert on it.
    async fn post_msgpack(MsgPack(msgpack): MsgPack<TestMsgPack>) -> String {
        format!(
            "msgpack: {}, {}, {}",
            msgpack.name,
            msgpack.age,
            msgpack.pets.unwrap_or_else(|| "pandas".to_string())
        )
    }

    /// Echoes back the request's Content-Type header, or an empty string when absent.
    async fn get_content_type(headers: HeaderMap) -> String {
        headers
            .get(CONTENT_TYPE)
            .map(|h| h.to_str().unwrap().to_string())
            .unwrap_or_default()
    }

    #[tokio::test]
    async fn it_should_pass_msgpack_up_to_be_read() {
        let app = Router::new().route("/msgpack", post(post_msgpack));
        let server = TestServer::new(app).expect("Should create test server");

        let text = server
            .post("/msgpack")
            .msgpack(&TestMsgPack {
                name: "Joe".to_string(),
                age: 20,
                pets: Some("foxes".to_string()),
            })
            .await
            .text();

        assert_eq!(text, "msgpack: Joe, 20, foxes");
    }

    #[tokio::test]
    async fn it_should_pass_msgpack_content_type_for_msgpack() {
        let app = Router::new().route("/content_type", post(get_content_type));
        let server = TestServer::new(app).expect("Should create test server");

        let text = server
            .post("/content_type")
            .msgpack(&json!({}))
            .await
            .text();

        assert_eq!(text, "application/msgpack");
    }
}
#[cfg(test)]
mod test_form {
    use crate::TestServer;
    use axum::routing::post;
    use axum::Form;
    use axum::Router;
    use http::header::CONTENT_TYPE;
    use http::HeaderMap;
    use serde::Deserialize;
    use serde::Serialize;

    /// Payload sent to the `/form` route.
    #[derive(Deserialize, Serialize)]
    struct TestForm {
        name: String,
        age: u32,
        pets: Option<String>,
    }

    /// Minimal form used only to check the content type.
    #[derive(Serialize)]
    struct MyForm {
        message: String,
    }

    /// Formats the received form so the test can assert on it.
    async fn get_form(Form(form): Form<TestForm>) -> String {
        let pets = form.pets.unwrap_or_else(|| "pandas".to_string());
        format!("form: {}, {}, {}", form.name, form.age, pets)
    }

    /// Echoes back the request's Content-Type header, or an empty string when absent.
    async fn get_content_type(headers: HeaderMap) -> String {
        headers
            .get(CONTENT_TYPE)
            .map(|value| value.to_str().unwrap().to_string())
            .unwrap_or_default()
    }

    #[tokio::test]
    async fn it_should_pass_form_up_to_be_read() {
        let server = TestServer::new(Router::new().route("/form", post(get_form)))
            .expect("Should create test server");
        let form = TestForm {
            name: "Joe".to_string(),
            age: 20,
            pets: Some("foxes".to_string()),
        };

        let response = server.post("/form").form(&form).await;

        response.assert_text("form: Joe, 20, foxes");
    }

    #[tokio::test]
    async fn it_should_pass_form_content_type_for_form() {
        let server = TestServer::new(Router::new().route("/content_type", post(get_content_type)))
            .expect("Should create test server");
        let form = MyForm {
            message: "hello".to_string(),
        };

        let response = server.post("/content_type").form(&form).await;

        response.assert_text("application/x-www-form-urlencoded");
    }
}
#[cfg(test)]
mod test_bytes {
    use crate::TestServer;
    use axum::extract::Request;
    use axum::routing::post;
    use axum::Router;
    use http::header::CONTENT_TYPE;
    use http::HeaderMap;
    use http_body_util::BodyExt;

    /// Returns the request body decoded as (lossy) UTF-8 text.
    async fn echo_body(request: Request) -> String {
        let collected = request
            .into_body()
            .collect()
            .await
            .expect("Should read body to bytes");
        String::from_utf8_lossy(&collected.to_bytes()).into_owned()
    }

    /// Echoes back the request's Content-Type header, or an empty string when absent.
    async fn echo_content_type(headers: HeaderMap) -> String {
        headers
            .get(CONTENT_TYPE)
            .map(|value| value.to_str().unwrap().to_string())
            .unwrap_or_default()
    }

    #[tokio::test]
    async fn it_should_pass_bytes_up_to_be_read() {
        let server = TestServer::new(Router::new().route("/bytes", post(echo_body)))
            .expect("Should create test server");

        let received = server
            .post("/bytes")
            .bytes("hello!".as_bytes().into())
            .await
            .text();

        assert_eq!(received, "hello!");
    }

    #[tokio::test]
    async fn it_should_not_change_content_type() {
        let server = TestServer::new(Router::new().route("/content_type", post(echo_content_type)))
            .expect("Should create test server");

        let received = server
            .post("/content_type")
            .content_type("application/testing")
            .bytes("hello!".as_bytes().into())
            .await
            .text();

        assert_eq!(received, "application/testing");
    }
}
#[cfg(test)]
mod test_bytes_from_file {
    use crate::TestServer;
    use axum::extract::Request;
    use axum::routing::post;
    use axum::Router;
    use http::header::CONTENT_TYPE;
    use http::HeaderMap;
    use http_body_util::BodyExt;

    /// Text fixture sent by these tests.
    const EXAMPLE_TEXT_FILE: &str = "files/example.txt";

    /// Returns the request body decoded as (lossy) UTF-8 text.
    async fn echo_body(request: Request) -> String {
        let collected = request
            .into_body()
            .collect()
            .await
            .expect("Should read body to bytes");
        String::from_utf8_lossy(&collected.to_bytes()).into_owned()
    }

    /// Echoes back the request's Content-Type header, or an empty string when absent.
    async fn echo_content_type(headers: HeaderMap) -> String {
        headers
            .get(CONTENT_TYPE)
            .map(|value| value.to_str().unwrap().to_string())
            .unwrap_or_default()
    }

    #[tokio::test]
    async fn it_should_pass_bytes_up_to_be_read() {
        let server = TestServer::new(Router::new().route("/bytes", post(echo_body)))
            .expect("Should create test server");

        let received = server
            .post("/bytes")
            .bytes_from_file(EXAMPLE_TEXT_FILE)
            .await
            .text();

        assert_eq!(received, "hello!");
    }

    #[tokio::test]
    async fn it_should_not_change_content_type() {
        let server = TestServer::new(Router::new().route("/content_type", post(echo_content_type)))
            .expect("Should create test server");

        let received = server
            .post("/content_type")
            .content_type("application/testing")
            .bytes_from_file(EXAMPLE_TEXT_FILE)
            .await
            .text();

        assert_eq!(received, "application/testing");
    }
}
#[cfg(test)]
mod test_text {
    use crate::TestServer;
    use axum::extract::Request;
    use axum::routing::post;
    use axum::Router;
    use http::header::CONTENT_TYPE;
    use http::HeaderMap;
    use http_body_util::BodyExt;

    /// Returns the request body decoded as (lossy) UTF-8 text.
    async fn echo_body(request: Request) -> String {
        let collected = request
            .into_body()
            .collect()
            .await
            .expect("Should read body to bytes");
        String::from_utf8_lossy(&collected.to_bytes()).into_owned()
    }

    /// Echoes back the request's Content-Type header, or an empty string when absent.
    async fn echo_content_type(headers: HeaderMap) -> String {
        headers
            .get(CONTENT_TYPE)
            .map(|value| value.to_str().unwrap().to_string())
            .unwrap_or_default()
    }

    #[tokio::test]
    async fn it_should_pass_text_up_to_be_read() {
        let server = TestServer::new(Router::new().route("/text", post(echo_body)))
            .expect("Should create test server");

        let received = server.post("/text").text("hello!").await.text();

        assert_eq!(received, "hello!");
    }

    #[tokio::test]
    async fn it_should_pass_text_content_type_for_text() {
        let server = TestServer::new(Router::new().route("/content_type", post(echo_content_type)))
            .expect("Should create test server");

        let received = server.post("/content_type").text("hello!").await.text();

        assert_eq!(received, "text/plain");
    }
}
#[cfg(test)]
mod test_text_from_file {
    use crate::TestServer;
    use axum::extract::Request;
    use axum::routing::post;
    use axum::Router;
    use http::header::CONTENT_TYPE;
    use http::HeaderMap;
    use http_body_util::BodyExt;

    /// Text fixture sent by these tests.
    const EXAMPLE_TEXT_FILE: &str = "files/example.txt";

    /// Returns the request body decoded as (lossy) UTF-8 text.
    async fn echo_body(request: Request) -> String {
        let collected = request
            .into_body()
            .collect()
            .await
            .expect("Should read body to bytes");
        String::from_utf8_lossy(&collected.to_bytes()).into_owned()
    }

    /// Echoes back the request's Content-Type header, or an empty string when absent.
    async fn echo_content_type(headers: HeaderMap) -> String {
        headers
            .get(CONTENT_TYPE)
            .map(|value| value.to_str().unwrap().to_string())
            .unwrap_or_default()
    }

    #[tokio::test]
    async fn it_should_pass_text_up_to_be_read() {
        let server = TestServer::new(Router::new().route("/text", post(echo_body)))
            .expect("Should create test server");

        let received = server
            .post("/text")
            .text_from_file(EXAMPLE_TEXT_FILE)
            .await
            .text();

        assert_eq!(received, "hello!");
    }

    #[tokio::test]
    async fn it_should_pass_text_content_type_for_text() {
        let server = TestServer::new(Router::new().route("/content_type", post(echo_content_type)))
            .expect("Should create test server");

        let received = server
            .post("/content_type")
            .text_from_file(EXAMPLE_TEXT_FILE)
            .await
            .text();

        assert_eq!(received, "text/plain");
    }
}
#[cfg(test)]
mod test_expect_success {
    use crate::TestServer;
    use axum::routing::get;
    use axum::Router;
    use http::StatusCode;

    /// Responds with the body "pong!".
    async fn get_ping() -> &'static str {
        "pong!"
    }

    /// Responds with status 202 Accepted.
    async fn get_accepted() -> StatusCode {
        StatusCode::ACCEPTED
    }

    #[tokio::test]
    async fn it_should_not_panic_if_success_is_returned() {
        let server = TestServer::new(Router::new().route("/ping", get(get_ping)))
            .expect("Should create test server");

        server.get("/ping").expect_success().await;
    }

    #[tokio::test]
    async fn it_should_not_panic_on_other_2xx_status_code() {
        let server = TestServer::new(Router::new().route("/accepted", get(get_accepted)))
            .expect("Should create test server");

        server.get("/accepted").expect_success().await;
    }

    #[tokio::test]
    #[should_panic]
    async fn it_should_panic_on_404() {
        let server = TestServer::new(Router::new()).expect("Should create test server");

        server.get("/some_unknown_route").expect_success().await;
    }

    #[tokio::test]
    async fn it_should_override_what_test_server_has_set() {
        let mut server = TestServer::new(Router::new().route("/ping", get(get_ping)))
            .expect("Should create test server");
        server.expect_failure();

        server.get("/ping").expect_success().await;
    }
}
#[cfg(test)]
mod test_expect_failure {
    use crate::TestServer;
    use axum::routing::get;
    use axum::Router;
    use http::StatusCode;

    /// Responds with the body "pong!".
    async fn get_ping() -> &'static str {
        "pong!"
    }

    /// Responds with status 202 Accepted.
    async fn get_accepted() -> StatusCode {
        StatusCode::ACCEPTED
    }

    #[tokio::test]
    async fn it_should_not_panic_if_expect_failure_on_404() {
        let server = TestServer::new(Router::new()).expect("Should create test server");

        server.get("/some_unknown_route").expect_failure().await;
    }

    #[tokio::test]
    #[should_panic]
    async fn it_should_panic_if_success_is_returned() {
        let server = TestServer::new(Router::new().route("/ping", get(get_ping)))
            .expect("Should create test server");

        server.get("/ping").expect_failure().await;
    }

    #[tokio::test]
    #[should_panic]
    async fn it_should_panic_on_other_2xx_status_code() {
        let server = TestServer::new(Router::new().route("/accepted", get(get_accepted)))
            .expect("Should create test server");

        server.get("/accepted").expect_failure().await;
    }

    #[tokio::test]
    async fn it_should_override_what_test_server_has_set() {
        let mut server = TestServer::new(Router::new()).expect("Should create test server");
        server.expect_success();

        server.get("/some_unknown_route").expect_failure().await;
    }
}
#[cfg(test)]
mod test_add_cookie {
    use crate::TestServer;
    use axum::routing::get;
    use axum::Router;
    use axum_extra::extract::cookie::CookieJar;
    use cookie::time::Duration;
    use cookie::time::OffsetDateTime;
    use cookie::Cookie;

    const TEST_COOKIE_NAME: &str = "test-cookie";

    /// Returns the value of the test cookie, or "cookie-not-found" when it was not sent.
    async fn get_cookie(cookies: CookieJar) -> (CookieJar, String) {
        let cookie_value = cookies
            .get(TEST_COOKIE_NAME)
            .map(|c| c.value().to_string())
            .unwrap_or_else(|| "cookie-not-found".to_string());

        (cookies, cookie_value)
    }

    #[tokio::test]
    async fn it_should_send_cookies_added_to_request() {
        let app = Router::new().route("/cookie", get(get_cookie));
        let server = TestServer::new(app).expect("Should create test server");

        let cookie = Cookie::new(TEST_COOKIE_NAME, "my-custom-cookie");
        let response_text = server.get("/cookie").add_cookie(cookie).await.text();

        assert_eq!(response_text, "my-custom-cookie");
    }

    #[tokio::test]
    async fn it_should_send_non_expired_cookies_added_to_request() {
        let app = Router::new().route("/cookie", get(get_cookie));
        let server = TestServer::new(app).expect("Should create test server");

        let mut cookie = Cookie::new(TEST_COOKIE_NAME, "my-custom-cookie");
        cookie.set_expires(
            OffsetDateTime::now_utc()
                .checked_add(Duration::minutes(10))
                .unwrap(),
        );
        let response_text = server.get("/cookie").add_cookie(cookie).await.text();

        assert_eq!(response_text, "my-custom-cookie");
    }

    #[tokio::test]
    async fn it_should_not_send_expired_cookies_added_to_request() {
        let app = Router::new().route("/cookie", get(get_cookie));
        let server = TestServer::new(app).expect("Should create test server");

        let mut cookie = Cookie::new(TEST_COOKIE_NAME, "my-custom-cookie");
        cookie.set_expires(OffsetDateTime::now_utc());
        let response_text = server.get("/cookie").add_cookie(cookie).await.text();

        assert_eq!(response_text, "cookie-not-found");
    }
}
#[cfg(test)]
mod test_add_cookies {
    use crate::TestServer;
    use axum::http::header::HeaderMap;
    use axum::routing::get;
    use axum::Router;
    use axum_extra::extract::cookie::CookieJar as AxumCookieJar;
    use cookie::Cookie;
    use cookie::CookieJar;
    use cookie::SameSite;

    /// Returns every received cookie as `name=value`, sorted and joined by ", ".
    async fn route_get_cookies(cookies: AxumCookieJar) -> String {
        let mut all_cookies = cookies
            .iter()
            .map(|cookie| format!("{}={}", cookie.name(), cookie.value()))
            .collect::<Vec<String>>();
        all_cookies.sort();

        all_cookies.join(", ")
    }

    /// Returns the raw `Cookie` header(s), joined by "; " when there are several.
    async fn get_cookie_headers_joined(headers: HeaderMap) -> String {
        headers
            .get_all("cookie")
            .into_iter()
            .map(|c| c.to_str().unwrap_or("").to_string())
            .reduce(|a, b| a + "; " + &b)
            .unwrap_or_default()
    }

    #[tokio::test]
    async fn it_should_send_all_cookies_added_by_jar() {
        let app = Router::new().route("/cookies", get(route_get_cookies));
        let server = TestServer::new(app).expect("Should create test server");

        let mut cookie_jar = CookieJar::new();
        cookie_jar.add(Cookie::new("first-cookie", "my-custom-cookie"));
        cookie_jar.add(Cookie::new("second-cookie", "other-cookie"));

        server
            .get("/cookies")
            .add_cookies(cookie_jar)
            .await
            .assert_text("first-cookie=my-custom-cookie, second-cookie=other-cookie");
    }

    #[tokio::test]
    async fn it_should_send_all_cookies_stripped_by_their_attributes() {
        const TEST_COOKIE_NAME: &str = "test-cookie";
        const TEST_COOKIE_VALUE: &str = "my-custom-cookie";

        let app = Router::new().route("/cookies", get(get_cookie_headers_joined));
        let server = TestServer::new(app).expect("Should create test server");

        let cookie = Cookie::build((TEST_COOKIE_NAME, TEST_COOKIE_VALUE))
            .http_only(true)
            .secure(true)
            .same_site(SameSite::Strict)
            .path("/cookie")
            .build();
        let mut cookie_jar = CookieJar::new();
        cookie_jar.add(cookie);

        server
            .get("/cookies")
            .add_cookies(cookie_jar)
            .await
            .assert_text(format!("{}={}", TEST_COOKIE_NAME, TEST_COOKIE_VALUE));
    }
}
#[cfg(test)]
mod test_save_cookies {
    use crate::TestServer;
    use axum::extract::Request;
    use axum::http::header::HeaderMap;
    use axum::routing::get;
    use axum::routing::put;
    use axum::Router;
    use axum_extra::extract::cookie::CookieJar as AxumCookieJar;
    use cookie::Cookie;
    use cookie::SameSite;
    use http_body_util::BodyExt;

    const TEST_COOKIE_NAME: &str = "test-cookie";

    /// Sets the test cookie (with several attributes) to the request body's text.
    async fn put_cookie_with_attributes(
        mut cookies: AxumCookieJar,
        request: Request,
    ) -> (AxumCookieJar, &'static str) {
        let body_bytes = request
            .into_body()
            .collect()
            .await
            .expect("Should turn the body into bytes")
            .to_bytes();
        let body_text: String = String::from_utf8_lossy(&body_bytes).to_string();

        let cookie = Cookie::build((TEST_COOKIE_NAME, body_text))
            .http_only(true)
            .secure(true)
            .same_site(SameSite::Strict)
            .path("/cookie")
            .build();
        cookies = cookies.add(cookie);

        (cookies, "done")
    }

    /// Returns the raw `Cookie` header(s), joined by "; " when there are several.
    async fn get_cookie_headers_joined(headers: HeaderMap) -> String {
        headers
            .get_all("cookie")
            .into_iter()
            .map(|c| c.to_str().unwrap_or("").to_string())
            .reduce(|a, b| a + "; " + &b)
            .unwrap_or_default()
    }

    #[tokio::test]
    async fn it_should_strip_cookies_from_their_attributes() {
        let app = Router::new()
            .route("/cookie", put(put_cookie_with_attributes))
            .route("/cookie", get(get_cookie_headers_joined));
        let server = TestServer::new(app).expect("Should create test server");

        server
            .put("/cookie")
            .text("cookie-found!")
            .save_cookies()
            .await;

        let response_text = server.get("/cookie").await.text();
        assert_eq!(response_text, format!("{}=cookie-found!", TEST_COOKIE_NAME));
    }
}
#[cfg(test)]
mod test_do_not_save_cookies {
    use crate::TestServer;
    use axum::extract::Request;
    use axum::http::header::HeaderMap;
    use axum::routing::get;
    use axum::routing::put;
    use axum::Router;
    use axum_extra::extract::cookie::CookieJar as AxumCookieJar;
    use cookie::Cookie;
    use cookie::SameSite;
    use http_body_util::BodyExt;

    const TEST_COOKIE_NAME: &str = "test-cookie";

    /// Stores the (lossily UTF-8 decoded) request body as the value of
    /// `TEST_COOKIE_NAME`, decorated with HttpOnly, Secure, SameSite=Strict
    /// and Path=/cookie attributes.
    async fn put_cookie_with_attributes(
        cookies: AxumCookieJar,
        request: Request,
    ) -> (AxumCookieJar, &'static str) {
        let body_bytes = request
            .into_body()
            .collect()
            .await
            .expect("Should turn the body into bytes")
            .to_bytes();
        let body_text = String::from_utf8_lossy(&body_bytes).into_owned();
        let cookie = Cookie::build((TEST_COOKIE_NAME, body_text))
            .http_only(true)
            .secure(true)
            .same_site(SameSite::Strict)
            .path("/cookie")
            .build();

        (cookies.add(cookie), "done")
    }

    /// Returns every raw `Cookie` request header joined with `"; "`
    /// (empty string when none were sent).
    async fn get_cookie_headers_joined(headers: HeaderMap) -> String {
        headers
            .get_all("cookie")
            .into_iter()
            .map(|value| value.to_str().unwrap_or(""))
            .collect::<Vec<&str>>()
            .join("; ")
    }

    /// `PUT /cookie` sets a cookie; `GET /cookie` echoes the cookies sent back.
    fn build_app() -> Router {
        Router::new()
            .route("/cookie", put(put_cookie_with_attributes))
            .route("/cookie", get(get_cookie_headers_joined))
    }

    #[tokio::test]
    async fn it_should_not_save_cookies_when_set() {
        let server = TestServer::new(build_app()).expect("Should create test server");

        server
            .put("/cookie")
            .text("cookie-found!")
            .do_not_save_cookies()
            .await;

        let response_text = server.get("/cookie").await.text();
        assert_eq!(response_text, "");
    }

    #[tokio::test]
    async fn it_should_override_test_server_and_not_save_cookies_when_set() {
        // The server-wide setting says "save", the request-level setting must win.
        let server = TestServer::builder()
            .save_cookies()
            .build(build_app())
            .expect("Should create test server");

        server
            .put("/cookie")
            .text("cookie-found!")
            .do_not_save_cookies()
            .await;

        let response_text = server.get("/cookie").await.text();
        assert_eq!(response_text, "");
    }
}
#[cfg(test)]
mod test_clear_cookies {
    use crate::TestServer;
    use axum::extract::Request;
    use axum::routing::get;
    use axum::routing::put;
    use axum::Router;
    use axum_extra::extract::cookie::Cookie as AxumCookie;
    use axum_extra::extract::cookie::CookieJar as AxumCookieJar;
    use cookie::Cookie;
    use cookie::CookieJar;
    use http_body_util::BodyExt;

    const TEST_COOKIE_NAME: &str = "test-cookie";

    /// Replies with the value of `TEST_COOKIE_NAME`, or `"cookie-not-found"`
    /// when the request did not carry that cookie.
    async fn get_cookie(cookies: AxumCookieJar) -> (AxumCookieJar, String) {
        let cookie_value = cookies
            .get(TEST_COOKIE_NAME)
            .map(|c| c.value().to_string())
            .unwrap_or_else(|| "cookie-not-found".to_string());
        (cookies, cookie_value)
    }

    /// Stores the (lossily UTF-8 decoded) request body as the value of
    /// `TEST_COOKIE_NAME`.
    async fn put_cookie(cookies: AxumCookieJar, request: Request) -> (AxumCookieJar, &'static str) {
        let body_bytes = request
            .into_body()
            .collect()
            .await
            .expect("Should turn the body into bytes")
            .to_bytes();
        let body_text = String::from_utf8_lossy(&body_bytes).into_owned();
        let cookie = AxumCookie::new(TEST_COOKIE_NAME, body_text);

        (cookies.add(cookie), "done")
    }

    #[tokio::test]
    async fn it_should_clear_cookie_added_to_request() {
        let app = Router::new().route("/cookie", get(get_cookie));
        let server = TestServer::new(app).expect("Should create test server");
        let cookie = Cookie::new(TEST_COOKIE_NAME, "my-custom-cookie");

        let response_text = server
            .get("/cookie")
            .add_cookie(cookie)
            .clear_cookies()
            .await
            .text();

        assert_eq!(response_text, "cookie-not-found");
    }

    #[tokio::test]
    async fn it_should_clear_cookie_jar_added_to_request() {
        let app = Router::new().route("/cookie", get(get_cookie));
        let server = TestServer::new(app).expect("Should create test server");
        let mut cookie_jar = CookieJar::new();
        cookie_jar.add(Cookie::new(TEST_COOKIE_NAME, "my-custom-cookie"));

        let response_text = server
            .get("/cookie")
            .add_cookies(cookie_jar)
            .clear_cookies()
            .await
            .text();

        assert_eq!(response_text, "cookie-not-found");
    }

    #[tokio::test]
    async fn it_should_clear_cookies_saved_by_past_request() {
        let app = Router::new()
            .route("/cookie", put(put_cookie))
            .route("/cookie", get(get_cookie));
        let server = TestServer::new(app).expect("Should create test server");

        server
            .put("/cookie")
            .text("cookie-found!")
            .save_cookies()
            .await;

        let response_text = server.get("/cookie").clear_cookies().await.text();
        assert_eq!(response_text, "cookie-not-found");
    }

    #[tokio::test]
    async fn it_should_clear_cookies_added_to_test_server() {
        let app = Router::new()
            .route("/cookie", put(put_cookie))
            .route("/cookie", get(get_cookie));
        let mut server = TestServer::new(app).expect("Should create test server");
        server.add_cookie(Cookie::new(TEST_COOKIE_NAME, "my-custom-cookie"));

        let response_text = server.get("/cookie").clear_cookies().await.text();
        assert_eq!(response_text, "cookie-not-found");
    }
}
#[cfg(test)]
mod test_add_header {
    use super::*;
    use crate::TestServer;
    use axum::async_trait;
    use axum::extract::FromRequestParts;
    use axum::routing::get;
    use axum::Router;
    use http::request::Parts;
    use http::HeaderName;
    use http::HeaderValue;
    use hyper::StatusCode;
    use std::marker::Sync;

    const TEST_HEADER_NAME: &str = "test-header";
    const TEST_HEADER_CONTENT: &str = "Test header content";

    /// Extractor holding the raw bytes of the `test-header` request header.
    struct TestHeader(Vec<u8>);

    #[async_trait]
    impl<S: Sync> FromRequestParts<S> for TestHeader {
        type Rejection = (StatusCode, &'static str);

        async fn from_request_parts(
            parts: &mut Parts,
            _state: &S,
        ) -> Result<TestHeader, Self::Rejection> {
            parts
                .headers
                .get(HeaderName::from_static(TEST_HEADER_NAME))
                .map(|v| TestHeader(v.as_bytes().to_vec()))
                .ok_or((StatusCode::BAD_REQUEST, "Missing test header"))
        }
    }

    /// Echoes the `test-header` value back as the response body.
    async fn ping_header(TestHeader(header): TestHeader) -> Vec<u8> {
        header
    }

    #[tokio::test]
    async fn it_should_send_header_added_to_request() {
        let app = Router::new().route("/header", get(ping_header));
        let server = TestServer::new(app).expect("Should create test server");

        let response = server
            .get("/header")
            .add_header(
                HeaderName::from_static(TEST_HEADER_NAME),
                HeaderValue::from_static(TEST_HEADER_CONTENT),
            )
            .await;

        response.assert_text(TEST_HEADER_CONTENT);
    }
}
#[cfg(test)]
mod test_authorization {
    use super::*;
    use crate::TestServer;
    use axum::async_trait;
    use axum::extract::FromRequestParts;
    use axum::routing::get;
    use axum::Router;
    use http::request::Parts;
    use hyper::StatusCode;
    use std::marker::Sync;

    /// Builds a server whose `/auth-header` route echoes the `Authorization`
    /// header back as text. The server is configured to expect success.
    fn new_test_server() -> TestServer {
        /// Extractor holding the `Authorization` header value as text.
        struct TestHeader(String);

        #[async_trait]
        impl<S: Sync> FromRequestParts<S> for TestHeader {
            type Rejection = (StatusCode, &'static str);

            async fn from_request_parts(
                parts: &mut Parts,
                _state: &S,
            ) -> Result<TestHeader, Self::Rejection> {
                let value = parts
                    .headers
                    .get(header::AUTHORIZATION)
                    .ok_or((StatusCode::BAD_REQUEST, "Missing test header"))?;
                // Reject, rather than panic on, values that are not visible ASCII.
                let text = value
                    .to_str()
                    .map_err(|_| (StatusCode::BAD_REQUEST, "Invalid test header"))?;

                Ok(TestHeader(text.to_string()))
            }
        }

        async fn ping_auth_header(TestHeader(header): TestHeader) -> String {
            header
        }

        let app = Router::new().route("/auth-header", get(ping_auth_header));
        let mut server = TestServer::new(app).expect("Should create test server");
        server.expect_success();
        server
    }

    #[tokio::test]
    async fn it_should_send_header_added_to_request() {
        let server = new_test_server();

        let response = server
            .get("/auth-header")
            .authorization("Bearer abc123")
            .await;

        response.assert_text("Bearer abc123");
    }
}
#[cfg(test)]
mod test_authorization_bearer {
    use super::*;
    use crate::TestServer;
    use axum::async_trait;
    use axum::extract::FromRequestParts;
    use axum::routing::get;
    use axum::Router;
    use http::request::Parts;
    use hyper::StatusCode;
    use std::marker::Sync;

    /// Builds a server whose `/auth-header` route echoes the bearer token from
    /// the `Authorization` header. The server is configured to expect success.
    fn new_test_server() -> TestServer {
        /// Extractor holding the `Authorization` value with a leading
        /// `"Bearer "` removed (if present).
        struct TestHeader(String);

        #[async_trait]
        impl<S: Sync> FromRequestParts<S> for TestHeader {
            type Rejection = (StatusCode, &'static str);

            async fn from_request_parts(
                parts: &mut Parts,
                _state: &S,
            ) -> Result<TestHeader, Self::Rejection> {
                let value = parts
                    .headers
                    .get(header::AUTHORIZATION)
                    .ok_or((StatusCode::BAD_REQUEST, "Missing test header"))?;
                // Reject, rather than panic on, values that are not visible ASCII.
                let text = value
                    .to_str()
                    .map_err(|_| (StatusCode::BAD_REQUEST, "Invalid test header"))?;
                // Only strip the scheme prefix; `replace` would also remove
                // "Bearer " occurring inside the token itself.
                let token = text.strip_prefix("Bearer ").unwrap_or(text);

                Ok(TestHeader(token.to_string()))
            }
        }

        async fn ping_auth_header(TestHeader(header): TestHeader) -> String {
            header
        }

        let app = Router::new().route("/auth-header", get(ping_auth_header));
        let mut server = TestServer::new(app).expect("Should create test server");
        server.expect_success();
        server
    }

    #[tokio::test]
    async fn it_should_send_header_added_to_request() {
        let server = new_test_server();

        let response = server
            .get("/auth-header")
            .authorization_bearer("abc123")
            .await;

        response.assert_text("abc123");
    }
}
#[cfg(test)]
mod test_clear_headers {
    use super::*;
    use crate::TestServer;
    use axum::async_trait;
    use axum::extract::FromRequestParts;
    use axum::routing::get;
    use axum::Router;
    use http::request::Parts;
    use http::HeaderName;
    use http::HeaderValue;
    use hyper::StatusCode;
    use std::marker::Sync;

    const TEST_HEADER_NAME: &str = "test-header";
    const TEST_HEADER_CONTENT: &str = "Test header content";

    /// Extractor holding the raw bytes of the `test-header` request header;
    /// rejects with 400 "Missing test header" when it is absent.
    struct TestHeader(Vec<u8>);

    #[async_trait]
    impl<S: Sync> FromRequestParts<S> for TestHeader {
        type Rejection = (StatusCode, &'static str);

        async fn from_request_parts(
            parts: &mut Parts,
            _state: &S,
        ) -> Result<TestHeader, Self::Rejection> {
            parts
                .headers
                .get(HeaderName::from_static(TEST_HEADER_NAME))
                .map(|v| TestHeader(v.as_bytes().to_vec()))
                .ok_or((StatusCode::BAD_REQUEST, "Missing test header"))
        }
    }

    /// Echoes the `test-header` value back as the response body.
    async fn ping_header(TestHeader(header): TestHeader) -> Vec<u8> {
        header
    }

    #[tokio::test]
    async fn it_should_clear_headers_added_to_request() {
        let app = Router::new().route("/header", get(ping_header));
        let server = TestServer::new(app).expect("Should create test server");

        let response = server
            .get("/header")
            .add_header(
                HeaderName::from_static(TEST_HEADER_NAME),
                HeaderValue::from_static(TEST_HEADER_CONTENT),
            )
            .clear_headers()
            .await;

        response.assert_status_bad_request();
        response.assert_text("Missing test header");
    }

    #[tokio::test]
    async fn it_should_clear_headers_added_to_server() {
        let app = Router::new().route("/header", get(ping_header));
        let mut server = TestServer::new(app).expect("Should create test server");
        server.add_header(
            HeaderName::from_static(TEST_HEADER_NAME),
            HeaderValue::from_static(TEST_HEADER_CONTENT),
        );

        let response = server.get("/header").clear_headers().await;

        response.assert_status_bad_request();
        response.assert_text("Missing test header");
    }
}
#[cfg(test)]
mod test_add_query_params {
    use crate::TestServer;
    use axum::extract::Query as AxumStdQuery;
    use axum::routing::get;
    use axum::Router;
    use serde::Deserialize;
    use serde::Serialize;
    use serde_json::json;

    /// Single `message` query parameter.
    #[derive(Debug, Deserialize, Serialize)]
    struct QueryParam {
        message: String,
    }

    /// Replies with the `message` query parameter.
    async fn get_query_param(AxumStdQuery(params): AxumStdQuery<QueryParam>) -> String {
        params.message
    }

    /// `message` plus `other` query parameters.
    #[derive(Debug, Deserialize, Serialize)]
    struct QueryParam2 {
        message: String,
        other: String,
    }

    /// Replies with `"{message}-{other}"`.
    async fn get_query_param_2(AxumStdQuery(params): AxumStdQuery<QueryParam2>) -> String {
        format!("{}-{}", params.message, params.other)
    }

    fn build_app() -> Router {
        Router::new()
            .route("/query", get(get_query_param))
            .route("/query-2", get(get_query_param_2))
    }

    /// A fresh test server for the app above.
    fn new_test_server() -> TestServer {
        TestServer::new(build_app()).expect("Should create test server")
    }

    #[tokio::test]
    async fn it_should_pass_up_query_params_from_serialization() {
        let params = QueryParam {
            message: "it works".to_string(),
        };

        new_test_server()
            .get("/query")
            .add_query_params(params)
            .await
            .assert_text("it works");
    }

    #[tokio::test]
    async fn it_should_pass_up_query_params_from_pairs() {
        let server = new_test_server();
        let request = server.get("/query").add_query_params(&[("message", "it works")]);
        request.await.assert_text("it works");
    }

    #[tokio::test]
    async fn it_should_pass_up_multiple_query_params_from_multiple_params() {
        let server = new_test_server();
        let pairs = [("message", "it works"), ("other", "yup")];
        server
            .get("/query-2")
            .add_query_params(&pairs)
            .await
            .assert_text("it works-yup");
    }

    #[tokio::test]
    async fn it_should_pass_up_multiple_query_params_from_multiple_calls() {
        let server = new_test_server();
        let request = server
            .get("/query-2")
            .add_query_params(&[("message", "it works")])
            .add_query_params(&[("other", "yup")]);
        request.await.assert_text("it works-yup");
    }

    #[tokio::test]
    async fn it_should_pass_up_multiple_query_params_from_json() {
        let server = new_test_server();
        let params = json!({
            "message": "it works",
            "other": "yup"
        });
        server
            .get("/query-2")
            .add_query_params(params)
            .await
            .assert_text("it works-yup");
    }
}
#[cfg(test)]
mod test_add_raw_query_param {
    use crate::TestServer;
    use axum::extract::Query as AxumStdQuery;
    use axum::routing::get;
    use axum::Router;
    use axum_extra::extract::Query as AxumExtraQuery;
    use serde::Deserialize;
    use serde::Serialize;

    /// Single `message` query parameter.
    #[derive(Debug, Deserialize, Serialize)]
    struct QueryParam {
        message: String,
    }

    /// Replies with the `message` query parameter.
    async fn get_query_param(AxumStdQuery(params): AxumStdQuery<QueryParam>) -> String {
        params.message
    }

    /// Repeated `items=...` and `arrs[]=...` query parameters; both default to empty.
    #[derive(Debug, Deserialize, Serialize)]
    struct QueryParamExtra {
        #[serde(default)]
        items: Vec<String>,

        #[serde(default, rename = "arrs[]")]
        arrs: Vec<String>,
    }

    /// Replies with all `items` followed by all `arrs[]` values, joined with `", "`.
    ///
    /// Chaining the two lists (instead of writing each list separately) keeps a
    /// separator between them when both are present.
    async fn get_query_param_extra(
        AxumExtraQuery(params): AxumExtraQuery<QueryParamExtra>,
    ) -> String {
        params
            .items
            .into_iter()
            .chain(params.arrs)
            .collect::<Vec<String>>()
            .join(", ")
    }

    fn build_app() -> Router {
        Router::new()
            .route("/query", get(get_query_param))
            .route("/query-extra", get(get_query_param_extra))
    }

    #[tokio::test]
    async fn it_should_pass_up_query_param_as_is() {
        let server = TestServer::new(build_app()).expect("Should create test server");

        server
            .get("/query")
            .add_raw_query_param("message=it-works")
            .await
            .assert_text("it-works");
    }

    #[tokio::test]
    async fn it_should_pass_up_array_query_params_as_one_string() {
        let server = TestServer::new(build_app()).expect("Should create test server");

        server
            .get("/query-extra")
            .add_raw_query_param("items=one&items=two&items=three")
            .await
            .assert_text("one, two, three");
    }

    #[tokio::test]
    async fn it_should_pass_up_array_query_params_as_multiple_params() {
        let server = TestServer::new(build_app()).expect("Should create test server");

        server
            .get("/query-extra")
            .add_raw_query_param("arrs[]=one")
            .add_raw_query_param("arrs[]=two")
            .add_raw_query_param("arrs[]=three")
            .await
            .assert_text("one, two, three");
    }
}
#[cfg(test)]
mod test_add_query_param {
    use crate::TestServer;
    use axum::extract::Query;
    use axum::routing::get;
    use axum::Router;
    use serde::Deserialize;
    use serde::Serialize;

    /// Single `message` query parameter.
    #[derive(Debug, Deserialize, Serialize)]
    struct QueryParam {
        message: String,
    }

    /// Replies with the `message` query parameter.
    async fn get_query_param(Query(params): Query<QueryParam>) -> String {
        params.message
    }

    /// `message` plus `other` query parameters.
    #[derive(Debug, Deserialize, Serialize)]
    struct QueryParam2 {
        message: String,
        other: String,
    }

    /// Replies with `"{message}-{other}"`.
    async fn get_query_param_2(Query(params): Query<QueryParam2>) -> String {
        format!("{}-{}", params.message, params.other)
    }

    #[tokio::test]
    async fn it_should_pass_up_query_params_from_pairs() {
        let server = TestServer::new(Router::new().route("/query", get(get_query_param)))
            .expect("Should create test server");

        let response = server
            .get("/query")
            .add_query_param("message", "it works")
            .await;

        response.assert_text("it works");
    }

    #[tokio::test]
    async fn it_should_pass_up_multiple_query_params_from_multiple_calls() {
        let server = TestServer::new(Router::new().route("/query-2", get(get_query_param_2)))
            .expect("Should create test server");

        // Each call appends one more parameter to the same request.
        let response = server
            .get("/query-2")
            .add_query_param("message", "it works")
            .add_query_param("other", "yup")
            .await;

        response.assert_text("it works-yup");
    }
}
#[cfg(test)]
mod test_clear_query_params {
    use crate::TestServer;
    use axum::extract::Query;
    use axum::routing::get;
    use axum::Router;
    use serde::Deserialize;
    use serde::Serialize;

    /// Two optional query parameters.
    #[derive(Debug, Deserialize, Serialize)]
    struct QueryParams {
        first: Option<String>,
        second: Option<String>,
    }

    /// Reports which of the two optional parameters were received.
    async fn get_query_params(Query(params): Query<QueryParams>) -> String {
        let has_first = params.first.is_some();
        let has_second = params.second.is_some();
        format!("has first? {}, has second? {}", has_first, has_second)
    }

    /// Builds a `QueryParams` with both fields populated.
    fn both_params() -> QueryParams {
        QueryParams {
            first: Some("first".to_string()),
            second: Some("second".to_string()),
        }
    }

    /// A fresh test server exposing `GET /query`.
    fn new_test_server() -> TestServer {
        let app = Router::new().route("/query", get(get_query_params));
        TestServer::new(app).expect("Should create test server")
    }

    #[tokio::test]
    async fn it_should_clear_all_params_set() {
        let server = new_test_server();
        let request = server
            .get("/query")
            .add_query_params(both_params())
            .clear_query_params();

        request.await.assert_text("has first? false, has second? false");
    }

    #[tokio::test]
    async fn it_should_clear_all_params_set_and_allow_replacement() {
        let server = new_test_server();
        // Params added after the clear must still be sent.
        let request = server
            .get("/query")
            .add_query_params(both_params())
            .clear_query_params()
            .add_query_params(both_params());

        request.await.assert_text("has first? true, has second? true");
    }
}
#[cfg(test)]
mod test_scheme {
    use crate::TestServer;
    use axum::extract::Request;
    use axum::routing::get;
    use axum::Router;

    /// Echoes the scheme of the incoming request URI (panics if it has none).
    async fn route_get_scheme(request: Request) -> String {
        let scheme = request.uri().scheme_str().unwrap();
        scheme.to_string()
    }

    /// A test server built with default builder settings, exposing `GET /scheme`.
    fn new_test_server() -> TestServer {
        let router = Router::new().route("/scheme", get(route_get_scheme));
        TestServer::builder().build(router).unwrap()
    }

    #[tokio::test]
    async fn it_should_return_http_by_default() {
        let server = new_test_server();
        server.get("/scheme").await.assert_text("http");
    }

    #[tokio::test]
    async fn it_should_return_http_when_set() {
        let server = new_test_server();
        let response = server.get("/scheme").scheme(&"http").await;
        response.assert_text("http");
    }

    #[tokio::test]
    async fn it_should_return_https_when_set() {
        let server = new_test_server();
        let response = server.get("/scheme").scheme(&"https").await;
        response.assert_text("https");
    }

    #[tokio::test]
    async fn it_should_override_test_server_when_set() {
        let mut server = new_test_server();
        server.scheme(&"https");

        // The request-level scheme takes precedence over the server-level one.
        let response = server.get("/scheme").scheme(&"http").await;
        response.assert_text("http");
    }
}
#[cfg(test)]
mod test_multipart {
    use crate::multipart::MultipartForm;
    use crate::multipart::Part;
    use crate::TestServer;
    use axum::extract::Multipart;
    use axum::routing::post;
    use axum::Json;
    use axum::Router;

    /// Reads every multipart field, in upload order, and describes each one as
    /// `"{name} is {byte count} bytes, {content type}"`.
    async fn route_post_multipart(mut multipart: Multipart) -> Json<Vec<String>> {
        let mut fields = vec![];
        while let Some(field) = multipart.next_field().await.unwrap() {
            let name = field.name().unwrap().to_string();
            let content_type = field.content_type().unwrap().to_owned();
            let data = field.bytes().await.unwrap();

            fields.push(format!("{name} is {} bytes, {content_type}", data.len()));
        }

        Json(fields)
    }

    fn test_router() -> Router {
        Router::new().route("/multipart", post(route_post_multipart))
    }

    /// The three-field text form shared by the transport tests.
    fn new_text_form() -> MultipartForm {
        MultipartForm::new()
            .add_text("penguins?", "lots")
            .add_text("animals", "🦊🦊🦊")
            .add_text("carrots", 123_u32)
    }

    /// Stats the route should report for `new_text_form`
    /// (each 🦊 is 4 bytes of UTF-8, so three of them are 12 bytes).
    fn expected_text_form_stats() -> Vec<String> {
        vec![
            "penguins? is 4 bytes, text/plain".to_string(),
            "animals is 12 bytes, text/plain".to_string(),
            "carrots is 3 bytes, text/plain".to_string(),
        ]
    }

    #[tokio::test]
    async fn it_should_get_multipart_stats_on_mock_transport() {
        let server = TestServer::builder()
            .mock_transport()
            .build(test_router())
            .expect("Should create test server");

        server
            .post("/multipart")
            .multipart(new_text_form())
            .await
            .assert_json(&expected_text_form_stats());
    }

    #[tokio::test]
    async fn it_should_get_multipart_stats_on_http_transport() {
        let server = TestServer::builder()
            .http_transport()
            .build(test_router())
            .expect("Should create test server");

        server
            .post("/multipart")
            .multipart(new_text_form())
            .await
            .assert_json(&expected_text_form_stats());
    }

    #[tokio::test]
    async fn it_should_send_text_parts_as_text() {
        let server = TestServer::builder()
            .mock_transport()
            .build(test_router())
            .expect("Should create test server");

        let form = MultipartForm::new().add_part("animals", Part::text("🦊🦊🦊"));

        server
            .post("/multipart")
            .multipart(form)
            .await
            .assert_json(&vec!["animals is 12 bytes, text/plain".to_string()]);
    }

    #[tokio::test]
    async fn it_should_send_custom_mime_type() {
        let server = TestServer::builder()
            .mock_transport()
            .build(test_router())
            .expect("Should create test server");

        let form = MultipartForm::new().add_part(
            "animals",
            Part::bytes("🦊,🦊,🦊".as_bytes()).mime_type(mime::TEXT_CSV),
        );

        server
            .post("/multipart")
            .multipart(form)
            .await
            .assert_json(&vec!["animals is 14 bytes, text/csv".to_string()]);
    }

    #[tokio::test]
    async fn it_should_send_using_include_bytes() {
        let server = TestServer::builder()
            .mock_transport()
            .build(test_router())
            .expect("Should create test server");

        let form = MultipartForm::new().add_part(
            "file",
            Part::bytes(include_bytes!("../rust-toolchain").as_slice()).mime_type(mime::TEXT_PLAIN),
        );

        server
            .post("/multipart")
            .multipart(form)
            .await
            .assert_json(&vec!["file is 6 bytes, text/plain".to_string()]);
    }
}