use std::{
collections::HashMap,
fmt::{self, Display, Formatter},
str::FromStr,
};
use tokio::io::AsyncBufRead;
use wasmparser::Payload;
mod cgi;
mod wcgi;
/// The CGI dialect a program speaks.
///
/// The dialect selects which module (`cgi` or `wcgi`) is used to build the
/// environment variables for an incoming HTTP request and to extract the
/// response header from the program's stdout.
#[derive(Copy, Clone, Debug, Default, Eq, PartialEq, serde::Serialize, serde::Deserialize)]
#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
pub enum CgiDialect {
/// RFC 3875 CGI (textual name `"rfc-3875"`); this is the default dialect.
#[default]
Rfc3875,
/// The WCGI dialect (textual name `"wcgi"`).
Wcgi,
}
impl CgiDialect {
    /// Name of the WebAssembly custom section that [`CgiDialect::from_wasm`]
    /// looks for. Its payload is expected to be a textual dialect name such as
    /// `"wcgi"` or `"rfc-3875"`.
    pub const CUSTOM_SECTION_NAME: &'static str = "cgi-dialect";

    /// Detects the dialect declared inside a WebAssembly module.
    ///
    /// Scans `wasm` for custom sections named [`Self::CUSTOM_SECTION_NAME`].
    /// It returns the dialect from the first such section whose payload meets
    /// both of these conditions:
    /// - it is valid UTF-8;
    /// - it parses as a known dialect name.
    ///
    /// Payloads that do not meet these conditions are skipped, and so are
    /// parse errors reported by `wasmparser`. Returns `None` when no section
    /// yields a dialect.
    pub fn from_wasm(wasm: &[u8]) -> Option<CgiDialect> {
        wasmparser::Parser::new(0)
            .parse_all(wasm)
            .filter_map(|payload| match payload {
                Ok(Payload::CustomSection(section))
                    if section.name() == Self::CUSTOM_SECTION_NAME =>
                {
                    Some(section.data())
                }
                _ => None,
            })
            // The first section that decodes to a known dialect wins.
            .find_map(|bytes| std::str::from_utf8(bytes).ok()?.parse().ok())
    }

    /// Fills `env` for the given request by delegating to this dialect's
    /// `prepare_environment_variables` (in the `cgi` or `wcgi` module).
    pub fn prepare_environment_variables(
        self,
        parts: http::request::Parts,
        env: &mut HashMap<String, String>,
    ) {
        match self {
            Self::Rfc3875 => cgi::prepare_environment_variables(parts, env),
            Self::Wcgi => wcgi::prepare_environment_variables(parts, env),
        }
    }

    /// Reads the response header from `stdout` using this dialect's rules.
    ///
    /// The work is delegated to the `cgi` or `wcgi` module.
    ///
    /// # Errors
    ///
    /// Propagates whatever [`CgiError`] the dialect-specific reader returns.
    pub async fn extract_response_header(
        self,
        stdout: &mut (impl AsyncBufRead + Unpin),
    ) -> Result<http::response::Parts, CgiError> {
        match self {
            Self::Rfc3875 => cgi::extract_response_header(stdout).await,
            Self::Wcgi => wcgi::extract_response_header(stdout).await,
        }
    }

    /// Canonical textual name of the dialect. This is the same string that
    /// [`FromStr`] accepts for it.
    pub const fn to_str(self) -> &'static str {
        match self {
            Self::Rfc3875 => "rfc-3875",
            Self::Wcgi => "wcgi",
        }
    }
}
impl FromStr for CgiDialect {
    type Err = UnknownCgiDialect;

    /// Parses a dialect from its canonical name (`"rfc-3875"` or `"wcgi"`).
    ///
    /// Matching is exact: case and surrounding whitespace are significant.
    /// Any other input yields [`UnknownCgiDialect`].
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let dialect = match s {
            "wcgi" => Self::Wcgi,
            "rfc-3875" => Self::Rfc3875,
            _ => return Err(UnknownCgiDialect),
        };
        Ok(dialect)
    }
}
impl Display for CgiDialect {
    /// Writes the dialect's canonical name, as returned by [`CgiDialect::to_str`].
    ///
    /// Uses [`Formatter::pad`] rather than `write!`, so width, fill, alignment
    /// and precision flags behave as they do for `&str`. For example,
    /// `format!("{:>10}", dialect)` right-aligns the name. A nested `write!`
    /// would silently discard those flags. Plain `{}` output is unchanged.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        f.pad(self.to_str())
    }
}
/// Error returned when a string does not name any known [`CgiDialect`].
///
/// This is the `FromStr::Err` type for `CgiDialect`. It carries no payload.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct UnknownCgiDialect;

impl Display for UnknownCgiDialect {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        f.write_str("Unknown CGI dialect")
    }
}

/// Leaf error: it has no underlying `source()`.
impl std::error::Error for UnknownCgiDialect {}
/// Errors returned by [`CgiDialect::extract_response_header`].
#[derive(Debug)]
pub enum CgiError {
/// Reading the program's stdout failed.
StdoutRead(std::io::Error),
/// A header could not be turned into a valid `http` response header.
InvalidHeaders {
/// The underlying `http` crate error.
error: http::Error,
/// The offending header name, as text.
header: String,
/// The offending header value, as text.
value: String,
},
/// A WCGI-dialect header could not be parsed.
MalformedWcgiHeader {
/// The underlying error reported by the `wcgi` crate.
error: ::wcgi::WcgiError,
/// The header text that failed to parse (presumably the raw header; confirm in `wcgi` module).
header: String,
},
}
impl std::error::Error for CgiError {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
match self {
CgiError::StdoutRead(e) => Some(e),
CgiError::InvalidHeaders { error, .. } => Some(error),
CgiError::MalformedWcgiHeader { error, .. } => error.source(),
}
}
}
impl fmt::Display for CgiError {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
match self {
CgiError::StdoutRead(_) => write!(f, "Unable to read the STDOUT pipe"),
CgiError::InvalidHeaders { header, value, .. } => {
write!(f, "Unable to parse header ({header}: {value})")
}
CgiError::MalformedWcgiHeader { header, .. } => {
write!(f, "Unable to parse WCGI header ({header})")
}
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a minimal WebAssembly module (header only) containing a single
    /// custom section with the given `name` and `payload`.
    fn module_with_custom_section(name: &str, payload: &[u8]) -> Vec<u8> {
        let name_len = u8::try_from(name.len()).expect("test names are short");
        let body_len = u8::try_from(1 + name.len() + payload.len()).expect("test bodies are short");
        // Keep both lengths below 0x80 so each one encodes as a single LEB128 byte.
        assert!(name_len < 0x80 && body_len < 0x80, "lengths must fit one LEB128 byte");

        let mut wasm = b"\0asm\x01\0\0\0".to_vec();
        // Section id 0 (custom), section size, then the name's length prefix.
        wasm.extend_from_slice(&[0x00, body_len, name_len]);
        wasm.extend_from_slice(name.as_bytes());
        wasm.extend_from_slice(payload);
        wasm
    }

    /// Every dialect must survive a `Display` -> `FromStr` round trip.
    #[test]
    fn round_trip_cgi_dialect_to_string() {
        for dialect in [CgiDialect::Rfc3875, CgiDialect::Wcgi] {
            let parsed: Result<CgiDialect, _> = dialect.to_string().parse();
            assert_eq!(parsed, Ok(dialect));
        }
    }

    /// Anything other than an exact canonical name is rejected.
    #[test]
    fn unknown_cgi_dialect_is_rejected() {
        for name in ["", "cgi", "WCGI", "rfc3875", " wcgi", "wcgi\n"] {
            assert_eq!(name.parse::<CgiDialect>(), Err(UnknownCgiDialect), "{name:?}");
        }
    }

    /// A `cgi-dialect` custom section with a known name selects that dialect.
    #[test]
    fn from_wasm_reads_dialect_custom_section() {
        let wasm = module_with_custom_section(CgiDialect::CUSTOM_SECTION_NAME, b"wcgi");
        assert_eq!(CgiDialect::from_wasm(&wasm), Some(CgiDialect::Wcgi));

        let wasm = module_with_custom_section(CgiDialect::CUSTOM_SECTION_NAME, b"rfc-3875");
        assert_eq!(CgiDialect::from_wasm(&wasm), Some(CgiDialect::Rfc3875));
    }

    /// Missing sections, other section names, unrecognised or non-UTF-8
    /// payloads, and non-wasm input all yield `None`.
    #[test]
    fn from_wasm_returns_none_without_a_usable_dialect() {
        assert_eq!(CgiDialect::from_wasm(b"\0asm\x01\0\0\0"), None);

        let other_name = module_with_custom_section("other", b"wcgi");
        assert_eq!(CgiDialect::from_wasm(&other_name), None);

        let bad_payload = module_with_custom_section(CgiDialect::CUSTOM_SECTION_NAME, b"nope");
        assert_eq!(CgiDialect::from_wasm(&bad_payload), None);

        let non_utf8 = module_with_custom_section(CgiDialect::CUSTOM_SECTION_NAME, &[0xff, 0xfe]);
        assert_eq!(CgiDialect::from_wasm(&non_utf8), None);

        assert_eq!(CgiDialect::from_wasm(b"not wasm"), None);
    }
}