use super::*;
use http::header::HeaderName;
use http::HeaderValue;
use indexmap::IndexMap;
use once_cell::sync::Lazy;
use pingora_error::{Error, ErrorType};
use regex::bytes::Regex;
use std::num::IntErrorKind;
use std::slice;
use std::str;
/// Value reported for a delta-seconds directive that is too large to fit in a `u32`.
///
/// Equal to 2^31 (2147483648). `DirectiveValue::parse_as_delta_seconds` substitutes this
/// constant when integer parsing fails with a positive overflow.
pub const DELTA_SECONDS_OVERFLOW_VALUE: u32 = 1 << 31;
/// Key of a cache-control directive (for example `max-age`).
///
/// `CacheControl` parsing stores keys lowercased, so lookups are expected to use lowercase names.
pub type DirectiveKey = String;
/// Raw bytes of a cache-control directive value, exactly as captured from the header.
///
/// The bytes are not required to be UTF-8 and still include the surrounding double quotes
/// when the value was written as a quoted-string; use the `parse_as_*` methods to interpret them.
#[derive(Debug)]
pub struct DirectiveValue(pub Vec<u8>);
impl AsRef<[u8]> for DirectiveValue {
fn as_ref(&self) -> &[u8] {
&self.0
}
}
impl DirectiveValue {
pub fn parse_as_bytes(&self) -> &[u8] {
self.0
.strip_prefix(b"\"")
.and_then(|bytes| bytes.strip_suffix(b"\""))
.unwrap_or(&self.0[..])
}
pub fn parse_as_str(&self) -> Result<&str> {
str::from_utf8(self.parse_as_bytes()).or_else(|e| {
Error::e_because(ErrorType::InternalError, "could not parse value as utf8", e)
})
}
pub fn parse_as_delta_seconds(&self) -> Result<u32> {
match self.parse_as_str()?.parse::<u32>() {
Ok(value) => Ok(value),
Err(e) => {
if e.kind() == &IntErrorKind::PosOverflow {
Ok(DELTA_SECONDS_OVERFLOW_VALUE)
} else {
Error::e_because(ErrorType::InternalError, "could not parse value as u32", e)
}
}
}
}
}
/// Insertion-ordered map from directive key to its optional value.
///
/// The value is `None` for a directive written without `=value` (for example `public`).
pub type DirectiveMap = IndexMap<DirectiveKey, Option<DirectiveValue>>;
/// Parsed representation of the directives found in cache-control style header lines.
#[derive(Debug)]
pub struct CacheControl {
/// All parsed directives, keyed by lowercased directive name.
pub directives: DirectiveMap,
}
/// Cacheability verdict derived from cache-control directives.
#[derive(Debug, PartialEq, Eq)]
pub enum Cacheable {
/// Explicitly cacheable; `CacheControl` returns this when `s-maxage`, `max-age` or `public`
/// is present (and neither `no-store` nor a valueless `private` is).
Yes,
/// Explicitly not cacheable; `CacheControl` returns this for `no-store` or a valueless `private`.
No,
/// The directives do not decide either way.
Default,
}
/// Iterator over the comma-separated elements of a directive value.
///
/// Wraps a byte-slice `split` whose predicate is stored as a plain function pointer.
pub struct ListValueIter<'a>(slice::Split<'a, u8, fn(&u8) -> bool>);
impl<'a> ListValueIter<'a> {
    /// Creates an iterator over the comma-separated elements of `value`.
    ///
    /// Enclosing double quotes are removed first (via `DirectiveValue::parse_as_bytes`);
    /// the remaining bytes are split on every `,`.
    pub fn from(value: &'a DirectiveValue) -> Self {
        fn is_comma(byte: &u8) -> bool {
            *byte == b','
        }

        let unquoted: &'a [u8] = value.parse_as_bytes();
        ListValueIter(unquoted.split(is_comma as fn(&u8) -> bool))
    }
}
/// Strips leading and trailing optional whitespace (space `0x20` and horizontal tab `0x09`)
/// from `bytes`, returning the inner sub-slice.
///
/// Input consisting solely of spaces/tabs (or empty input) yields an empty slice.
fn trim_ows(bytes: &[u8]) -> &[u8] {
    let is_ows = |b: &u8| matches!(*b, b'\x20' | b'\x09');

    match bytes.iter().position(|b| !is_ows(b)) {
        // Nothing but whitespace (or nothing at all).
        None => &bytes[..0],
        Some(start) => {
            // One past the last non-whitespace byte; one is guaranteed to exist here.
            let end = bytes
                .iter()
                .rposition(|b| !is_ows(b))
                .map_or(start, |last| last + 1);
            &bytes[start..end]
        }
    }
}
impl<'a> Iterator for ListValueIter<'a> {
type Item = &'a [u8];
fn next(&mut self) -> Option<Self::Item> {
Some(trim_ows(self.0.next()?))
}
}
// Lazily compiled byte regex that matches one cache-control directive per match.
//
// Structure of the pattern:
// - `(?-u)`: Unicode mode off, so the pattern operates on raw bytes.
// - `(?:^|(?:\s*[,;]\s*))`: start of input, or a `,` / `;` delimiter with optional
//   surrounding whitespace.
// - capture group 1: the directive key, one or more bytes that are not control characters,
//   space, DEL, or one of the separators `( ) < > @ , ; : \ " / [ ] ? = { }`.
// - optional `=` followed by capture group 2: either a token with the same byte restrictions
//   or a double-quoted string (backslash escapes allowed inside; quotes kept in the capture).
static RE_CACHE_DIRECTIVE: Lazy<Regex> =
Lazy::new(|| {
Regex::new(r#"(?-u)(?:^|(?:\s*[,;]\s*))([^\x00-\x20\(\)<>@,;:\\"/\[\]\?=\{\}\x7F]+)(?:=((?:[^\x00-\x20\(\)<>@,;:\\"/\[\]\?=\{\}\x7F]+|(?:"(?:[^"\\]|\\.)*"))))?"#).unwrap()
});
impl CacheControl {
    // Builds the directive map from every value of a single header, in iteration order.
    //
    // Keys are lowercased. A directive whose key is not valid UTF-8 is skipped. Values keep
    // their raw bytes (quotes included). This currently always returns `Some`.
    fn from_headers(headers: http::header::GetAll<HeaderValue>) -> Option<Self> {
        let mut directives = IndexMap::new();
        for line in headers {
            for captures in RE_CACHE_DIRECTIVE.captures_iter(line.as_bytes()) {
                // Group 1 is the directive key; header bytes need not be UTF-8.
                let key = match captures
                    .get(1)
                    .and_then(|cap| str::from_utf8(cap.as_bytes()).ok())
                {
                    Some(token) => token.to_lowercase(),
                    None => continue,
                };
                // Group 2 is the optional value (token or quoted-string).
                let value = captures
                    .get(2)
                    .map(|cap| DirectiveValue(cap.as_bytes().to_vec()));
                directives.insert(key, value);
            }
        }
        Some(CacheControl { directives })
    }

    /// Parses the directives from all values of the header `header_name` in `headers`.
    ///
    /// Returns `None` when the header is absent.
    pub fn from_headers_named(header_name: &str, headers: &http::HeaderMap) -> Option<Self> {
        if headers.contains_key(header_name) {
            Self::from_headers(headers.get_all(header_name))
        } else {
            None
        }
    }

    /// Parses the directives from the request header named `header_name`.
    pub fn from_req_headers_named(header_name: &str, req_header: &ReqHeader) -> Option<Self> {
        Self::from_headers_named(header_name, &req_header.headers)
    }

    /// Parses the directives from the request's `cache-control` header.
    pub fn from_req_headers(req_header: &ReqHeader) -> Option<Self> {
        Self::from_req_headers_named("cache-control", req_header)
    }

    /// Parses the directives from the response header named `header_name`.
    pub fn from_resp_headers_named(header_name: &str, resp_header: &RespHeader) -> Option<Self> {
        Self::from_headers_named(header_name, &resp_header.headers)
    }

    /// Parses the directives from the response's `cache-control` header.
    pub fn from_resp_headers(resp_header: &RespHeader) -> Option<Self> {
        Self::from_resp_headers_named("cache-control", resp_header)
    }

    /// Whether the directive `key` is present, with or without a value.
    ///
    /// `key` is compared as-is against the stored (lowercased) keys.
    pub fn has_key(&self, key: &str) -> bool {
        self.directives.get(key).is_some()
    }

    /// Whether `public` is present (any value is ignored).
    pub fn public(&self) -> bool {
        self.has_key("public")
    }

    // Whether `key` is present in its bare form, i.e. without `=value`.
    fn has_key_without_value(&self, key: &str) -> bool {
        match self.directives.get(key) {
            Some(value) => value.is_none(),
            None => false,
        }
    }

    /// Whether the bare `private` directive (no field-name list) is present.
    pub fn private(&self) -> bool {
        self.has_key_without_value("private")
    }

    // Iterator over the list elements of `key`'s value; `None` when the directive is
    // absent or has no value.
    fn get_field_names(&self, key: &str) -> Option<ListValueIter> {
        match self.directives.get(key) {
            Some(Some(value)) => Some(ListValueIter::from(value)),
            _ => None,
        }
    }

    /// Field names listed in `private="..."`, if that form is present.
    pub fn private_field_names(&self) -> Option<ListValueIter> {
        self.get_field_names("private")
    }

    /// Whether the bare `no-cache` directive (no field-name list) is present.
    pub fn no_cache(&self) -> bool {
        self.has_key_without_value("no-cache")
    }

    /// Field names listed in `no-cache="..."`, if that form is present.
    pub fn no_cache_field_names(&self) -> Option<ListValueIter> {
        self.get_field_names("no-cache")
    }

    /// Whether `no-store` is present (any value is ignored).
    pub fn no_store(&self) -> bool {
        self.has_key("no-store")
    }

    // Parses `key`'s value as delta-seconds. `Ok(None)` when the directive is absent or has
    // no value; `Err` when the value cannot be parsed.
    fn parse_delta_seconds(&self, key: &str) -> Result<Option<u32>> {
        match self.directives.get(key) {
            Some(Some(dir_value)) => {
                let seconds = dir_value.parse_as_delta_seconds()?;
                Ok(Some(seconds))
            }
            _ => Ok(None),
        }
    }

    /// Seconds given by `max-age`, if present with a value.
    ///
    /// # Errors
    ///
    /// Fails when the value cannot be parsed as delta-seconds.
    pub fn max_age(&self) -> Result<Option<u32>> {
        self.parse_delta_seconds("max-age")
    }

    /// Seconds given by `s-maxage`, if present with a value.
    ///
    /// # Errors
    ///
    /// Fails when the value cannot be parsed as delta-seconds.
    pub fn s_maxage(&self) -> Result<Option<u32>> {
        self.parse_delta_seconds("s-maxage")
    }

    /// Seconds given by `stale-while-revalidate`, if present with a value.
    ///
    /// # Errors
    ///
    /// Fails when the value cannot be parsed as delta-seconds.
    pub fn stale_while_revalidate(&self) -> Result<Option<u32>> {
        self.parse_delta_seconds("stale-while-revalidate")
    }

    /// Seconds given by `stale-if-error`, if present with a value.
    ///
    /// # Errors
    ///
    /// Fails when the value cannot be parsed as delta-seconds.
    pub fn stale_if_error(&self) -> Result<Option<u32>> {
        self.parse_delta_seconds("stale-if-error")
    }

    /// Whether `must-revalidate` is present (any value is ignored).
    pub fn must_revalidate(&self) -> bool {
        self.has_key("must-revalidate")
    }

    /// Whether `proxy-revalidate` is present (any value is ignored).
    pub fn proxy_revalidate(&self) -> bool {
        self.has_key("proxy-revalidate")
    }

    /// Whether `only-if-cached` is present (any value is ignored).
    pub fn only_if_cached(&self) -> bool {
        self.has_key("only-if-cached")
    }
}
impl InterpretCacheControl for CacheControl {
    /// `No` for `no-store` or bare `private`; otherwise `Yes` when `s-maxage`, `max-age` or
    /// `public` is present; otherwise `Default`.
    fn is_cacheable(&self) -> Cacheable {
        if self.no_store() || self.private() {
            Cacheable::No
        } else if self.has_key("s-maxage") || self.has_key("max-age") || self.public() {
            Cacheable::Yes
        } else {
            Cacheable::Default
        }
    }

    /// True when `must-revalidate`, `public` or `s-maxage` is present.
    fn allow_caching_authorized_req(&self) -> bool {
        let permitting = ["s-maxage"];
        self.must_revalidate() || self.public() || permitting.iter().any(|key| self.has_key(key))
    }

    /// Freshness lifetime in seconds.
    ///
    /// Bare `no-cache` forces `Some(0)`. Otherwise `s-maxage` wins over `max-age`; an
    /// unparsable `s-maxage` yields `None` without consulting `max-age`, and an unparsable
    /// `max-age` also yields `None`.
    fn fresh_sec(&self) -> Option<u32> {
        if self.no_cache() {
            return Some(0);
        }
        match self.s_maxage() {
            Ok(Some(shared_ttl)) => Some(shared_ttl),
            // `s-maxage` absent (or valueless): fall back to `max-age`.
            Ok(None) => self.max_age().unwrap_or(None),
            Err(_) => None,
        }
    }

    /// Seconds stale content may be served while revalidating.
    ///
    /// `Some(0)` when `must-revalidate`, `proxy-revalidate` or `s-maxage` is present;
    /// otherwise the `stale-while-revalidate` value, with parse errors mapped to `None`.
    fn serve_stale_while_revalidate_sec(&self) -> Option<u32> {
        let stale_forbidden =
            self.must_revalidate() || self.proxy_revalidate() || self.has_key("s-maxage");
        if stale_forbidden {
            Some(0)
        } else {
            self.stale_while_revalidate().ok().flatten()
        }
    }

    /// Seconds stale content may be served on error.
    ///
    /// `Some(0)` when `must-revalidate`, `proxy-revalidate` or `s-maxage` is present;
    /// otherwise the `stale-if-error` value, with parse errors mapped to `None`.
    fn serve_stale_if_error_sec(&self) -> Option<u32> {
        let stale_forbidden =
            self.must_revalidate() || self.proxy_revalidate() || self.has_key("s-maxage");
        if stale_forbidden {
            Some(0)
        } else {
            self.stale_if_error().ok().flatten()
        }
    }

    /// Removes from `resp_header` every header named in `private="..."` and then in
    /// `no-cache="..."`. List entries that are not valid header names are ignored.
    fn strip_private_headers(&self, resp_header: &mut ResponseHeader) {
        let listed_names = self
            .private_field_names()
            .into_iter()
            .flatten()
            .chain(self.no_cache_field_names().into_iter().flatten());

        for name in listed_names {
            if let Ok(header) = HeaderName::from_bytes(name) {
                resp_header.remove_header(&header);
            }
        }
    }
}
/// Interprets parsed cache-control directives as concrete caching decisions.
///
/// `CacheControl` provides an implementation in this file; other implementations may apply
/// different policies.
pub trait InterpretCacheControl {
/// Whether the directives mark the response as cacheable, not cacheable, or undecided.
fn is_cacheable(&self) -> Cacheable;
/// Whether a response to a request carrying authorization may be cached.
/// (`CacheControl` answers true for `must-revalidate`, `public` or `s-maxage`.)
fn allow_caching_authorized_req(&self) -> bool;
/// Freshness lifetime in seconds; `None` when no usable lifetime is specified.
/// `Some(0)` means the response is immediately stale.
fn fresh_sec(&self) -> Option<u32>;
/// Seconds stale content may be served while revalidating; `None` when unspecified.
fn serve_stale_while_revalidate_sec(&self) -> Option<u32>;
/// Seconds stale content may be served on error; `None` when unspecified.
fn serve_stale_if_error_sec(&self) -> Option<u32>;
/// Removes from `resp_header` the headers that the directives single out as not to be stored.
/// (`CacheControl` uses the field-name lists of `private` and `no-cache`.)
fn strip_private_headers(&self, resp_header: &mut ResponseHeader);
}
// Unit tests for cache-control parsing and interpretation.
#[cfg(test)]
mod tests {
use super::*;
use http::header::CACHE_CONTROL;
use http::{request, response};
// Builds response parts carrying a single header `cc_key: cc_value`.
fn build_response(cc_key: HeaderName, cc_value: &str) -> response::Parts {
let (parts, _) = response::Builder::new()
.header(cc_key, cc_value)
.body(())
.unwrap()
.into_parts();
parts
}
// A bare directive and a valued directive on one line are both recognized.
#[test]
fn test_simple_cache_control() {
let resp = build_response(CACHE_CONTROL, "public, max-age=10000");
let cc = CacheControl::from_resp_headers(&resp).unwrap();
assert!(cc.public());
assert_eq!(cc.max_age().unwrap().unwrap(), 10000);
}
// Bare `private` is detected; an absent `max-age` is `Ok(None)`.
#[test]
fn test_private_cache_control() {
let resp = build_response(CACHE_CONTROL, "private");
let cc = CacheControl::from_resp_headers(&resp).unwrap();
assert!(cc.private());
assert!(cc.max_age().unwrap().is_none());
}
// Directives spread over several header lines (any header-name casing) are merged.
#[test]
fn test_directives_across_header_lines() {
let (parts, _) = response::Builder::new()
.header(CACHE_CONTROL, "public,")
.header("cache-Control", "max-age=10000")
.body(())
.unwrap()
.into_parts();
let cc = CacheControl::from_resp_headers(&parts).unwrap();
assert!(cc.public());
assert_eq!(cc.max_age().unwrap().unwrap(), 10000);
}
// `;` is accepted as a directive delimiter in addition to `,`.
#[test]
fn test_recognizes_semicolons_as_delimiters() {
let resp = build_response(CACHE_CONTROL, "public; max-age=0");
let cc = CacheControl::from_resp_headers(&resp).unwrap();
assert!(cc.public());
assert_eq!(cc.max_age().unwrap().unwrap(), 0);
}
// Unknown directives are kept, in order, with their raw values (quotes included).
#[test]
fn test_unknown_directives() {
let resp = build_response(CACHE_CONTROL, "public,random1=random2, rand3=\"\"");
let cc = CacheControl::from_resp_headers(&resp).unwrap();
let mut directive_iter = cc.directives.iter();
let first = directive_iter.next().unwrap();
assert_eq!(first.0, &"public");
assert!(first.1.is_none());
let second = directive_iter.next().unwrap();
assert_eq!(second.0, &"random1");
assert_eq!(second.1.as_ref().unwrap().0, "random2".as_bytes());
let third = directive_iter.next().unwrap();
assert_eq!(third.0, &"rand3");
assert_eq!(third.1.as_ref().unwrap().0, "\"\"".as_bytes());
assert!(directive_iter.next().is_none());
}
// Keys are lowercased; values keep their original case and quoting.
#[test]
fn test_case_insensitive_directive_keys() {
let resp = build_response(
CACHE_CONTROL,
"Public=\"something\", mAx-AGe=\"10000\", foo=cRaZyCaSe, bAr=\"inQuotes\"",
);
let cc = CacheControl::from_resp_headers(&resp).unwrap();
assert!(cc.public());
assert_eq!(cc.max_age().unwrap().unwrap(), 10000);
let mut directive_iter = cc.directives.iter();
let first = directive_iter.next().unwrap();
assert_eq!(first.0, &"public");
assert_eq!(first.1.as_ref().unwrap().0, "\"something\"".as_bytes());
let second = directive_iter.next().unwrap();
assert_eq!(second.0, &"max-age");
assert_eq!(second.1.as_ref().unwrap().0, "\"10000\"".as_bytes());
let third = directive_iter.next().unwrap();
assert_eq!(third.0, &"foo");
assert_eq!(third.1.as_ref().unwrap().0, "cRaZyCaSe".as_bytes());
let fourth = directive_iter.next().unwrap();
assert_eq!(fourth.0, &"bar");
assert_eq!(fourth.1.as_ref().unwrap().0, "\"inQuotes\"".as_bytes());
assert!(directive_iter.next().is_none());
}
// Non-ASCII UTF-8 keys and values are stored, but do not match known directives
// and fail numeric parsing.
#[test]
fn test_non_ascii() {
let resp = build_response(CACHE_CONTROL, "püblic=💖, max-age=\"💯\"");
let cc = CacheControl::from_resp_headers(&resp).unwrap();
assert!(!cc.public());
assert_eq!(
cc.max_age().unwrap_err().context.unwrap().to_string(),
"could not parse value as u32"
);
let mut directive_iter = cc.directives.iter();
let first = directive_iter.next().unwrap();
assert_eq!(first.0, &"püblic");
assert_eq!(first.1.as_ref().unwrap().0, "💖".as_bytes());
let second = directive_iter.next().unwrap();
assert_eq!(second.0, &"max-age");
assert_eq!(second.1.as_ref().unwrap().0, "\"💯\"".as_bytes());
assert!(directive_iter.next().is_none());
}
// A directive whose key is not valid UTF-8 is dropped; later directives still parse.
#[test]
fn test_non_utf8_key() {
let mut resp = response::Builder::new().body(()).unwrap();
resp.headers_mut().insert(
CACHE_CONTROL,
HeaderValue::from_bytes(b"bar\xFF=\"baz\", a=b").unwrap(),
);
let (parts, _) = resp.into_parts();
let cc = CacheControl::from_resp_headers(&parts).unwrap();
let mut directive_iter = cc.directives.iter();
let first = directive_iter.next().unwrap();
assert_eq!(first.0, &"a");
assert_eq!(first.1.as_ref().unwrap().0, "b".as_bytes());
assert!(directive_iter.next().is_none());
}
// Non-UTF-8 values are stored as raw bytes; interpreting them as text reports a utf8 error.
#[test]
fn test_non_utf8_value() {
let mut resp = response::Builder::new().body(()).unwrap();
resp.headers_mut().insert(
CACHE_CONTROL,
HeaderValue::from_bytes(b"max-age=ba\xFFr, bar=\"baz\xFF\", a=b").unwrap(),
);
let (parts, _) = resp.into_parts();
let cc = CacheControl::from_resp_headers(&parts).unwrap();
assert_eq!(
cc.max_age().unwrap_err().context.unwrap().to_string(),
"could not parse value as utf8"
);
let mut directive_iter = cc.directives.iter();
let first = directive_iter.next().unwrap();
assert_eq!(first.0, &"max-age");
assert_eq!(first.1.as_ref().unwrap().0, b"ba\xFFr");
let second = directive_iter.next().unwrap();
assert_eq!(second.0, &"bar");
assert_eq!(second.1.as_ref().unwrap().0, b"\"baz\xFF\"");
let third = directive_iter.next().unwrap();
assert_eq!(third.0, &"a");
assert_eq!(third.1.as_ref().unwrap().0, "b".as_bytes());
assert!(directive_iter.next().is_none());
}
// Positive overflow saturates to DELTA_SECONDS_OVERFLOW_VALUE; a negative value is an error.
#[test]
fn test_age_overflow() {
let resp = build_response(
CACHE_CONTROL,
"max-age=-99999999999999999999999999, s-maxage=99999999999999999999999999",
);
let cc = CacheControl::from_resp_headers(&resp).unwrap();
assert_eq!(
cc.s_maxage().unwrap().unwrap(),
DELTA_SECONDS_OVERFLOW_VALUE
);
assert_eq!(
cc.max_age().unwrap_err().context.unwrap().to_string(),
"could not parse value as u32"
);
}
// fresh_sec: none without directives, max-age when alone, s-maxage takes precedence.
#[test]
fn test_fresh_sec() {
let resp = build_response(CACHE_CONTROL, "");
let cc = CacheControl::from_resp_headers(&resp).unwrap();
assert!(cc.fresh_sec().is_none());
let resp = build_response(CACHE_CONTROL, "max-age=12345");
let cc = CacheControl::from_resp_headers(&resp).unwrap();
assert_eq!(cc.fresh_sec().unwrap(), 12345);
let resp = build_response(CACHE_CONTROL, "max-age=99999,s-maxage=123");
let cc = CacheControl::from_resp_headers(&resp).unwrap();
assert_eq!(cc.fresh_sec().unwrap(), 123);
}
// is_cacheable: private/no-store override max-age; public or max-age alone mean Yes.
#[test]
fn test_cacheability() {
let resp = build_response(CACHE_CONTROL, "");
let cc = CacheControl::from_resp_headers(&resp).unwrap();
assert_eq!(cc.is_cacheable(), Cacheable::Default);
let resp = build_response(CACHE_CONTROL, "private, max-age=12345");
let cc = CacheControl::from_resp_headers(&resp).unwrap();
assert_eq!(cc.is_cacheable(), Cacheable::No);
let resp = build_response(CACHE_CONTROL, "no-store, max-age=12345");
let cc = CacheControl::from_resp_headers(&resp).unwrap();
assert_eq!(cc.is_cacheable(), Cacheable::No);
let resp = build_response(CACHE_CONTROL, "public");
let cc = CacheControl::from_resp_headers(&resp).unwrap();
assert_eq!(cc.is_cacheable(), Cacheable::Yes);
let resp = build_response(CACHE_CONTROL, "max-age=0");
let cc = CacheControl::from_resp_headers(&resp).unwrap();
assert_eq!(cc.is_cacheable(), Cacheable::Yes);
}
// Bare no-cache keeps the response cacheable but forces a freshness of 0 seconds.
#[test]
fn test_no_cache() {
let resp = build_response(CACHE_CONTROL, "no-cache, max-age=12345");
let cc = CacheControl::from_resp_headers(&resp).unwrap();
assert_eq!(cc.is_cacheable(), Cacheable::Yes);
assert_eq!(cc.fresh_sec().unwrap(), 0);
}
// The field-name forms of no-cache/private do not act as the bare directives; their lists
// are split on commas with spaces/tabs trimmed, and empty elements are preserved.
#[test]
fn test_no_cache_field_names() {
let resp = build_response(CACHE_CONTROL, "no-cache=\"set-cookie\", max-age=12345");
let cc = CacheControl::from_resp_headers(&resp).unwrap();
assert!(!cc.private());
assert_eq!(cc.is_cacheable(), Cacheable::Yes);
assert_eq!(cc.fresh_sec().unwrap(), 12345);
let mut field_names = cc.no_cache_field_names().unwrap();
assert_eq!(
str::from_utf8(field_names.next().unwrap()).unwrap(),
"set-cookie"
);
assert!(field_names.next().is_none());
let mut resp = response::Builder::new().body(()).unwrap();
resp.headers_mut().insert(
CACHE_CONTROL,
HeaderValue::from_bytes(
b"private=\"\", no-cache=\"a\xFF, set-cookie, Baz\x09 , c,d ,, \"",
)
.unwrap(),
);
let (parts, _) = resp.into_parts();
let cc = CacheControl::from_resp_headers(&parts).unwrap();
let mut field_names = cc.private_field_names().unwrap();
assert_eq!(str::from_utf8(field_names.next().unwrap()).unwrap(), "");
assert!(field_names.next().is_none());
let mut field_names = cc.no_cache_field_names().unwrap();
assert!(str::from_utf8(field_names.next().unwrap()).is_err());
assert_eq!(
str::from_utf8(field_names.next().unwrap()).unwrap(),
"set-cookie"
);
assert_eq!(str::from_utf8(field_names.next().unwrap()).unwrap(), "Baz");
assert_eq!(str::from_utf8(field_names.next().unwrap()).unwrap(), "c");
assert_eq!(str::from_utf8(field_names.next().unwrap()).unwrap(), "d");
assert_eq!(str::from_utf8(field_names.next().unwrap()).unwrap(), "");
assert_eq!(str::from_utf8(field_names.next().unwrap()).unwrap(), "");
assert!(field_names.next().is_none());
}
// Headers named in no-cache="..." are removed from the response header.
#[test]
fn test_strip_private_headers() {
let mut resp = ResponseHeader::build(200, None).unwrap();
resp.append_header(
CACHE_CONTROL,
"no-cache=\"x-private-header\", max-age=12345",
)
.unwrap();
resp.append_header("X-Private-Header", "dropped").unwrap();
let cc = CacheControl::from_resp_headers(&resp).unwrap();
cc.strip_private_headers(&mut resp);
assert!(!resp.headers.contains_key("X-Private-Header"));
}
// stale-while-revalidate is reported; stale-if-error stays unspecified.
#[test]
fn test_stale_while_revalidate() {
let resp = build_response(CACHE_CONTROL, "max-age=12345, stale-while-revalidate=5");
let cc = CacheControl::from_resp_headers(&resp).unwrap();
assert_eq!(cc.stale_while_revalidate().unwrap().unwrap(), 5);
assert_eq!(cc.serve_stale_while_revalidate_sec().unwrap(), 5);
assert!(cc.serve_stale_if_error_sec().is_none());
}
// stale-if-error is reported; stale-while-revalidate stays unspecified.
#[test]
fn test_stale_if_error() {
let resp = build_response(CACHE_CONTROL, "max-age=12345, stale-if-error=3600");
let cc = CacheControl::from_resp_headers(&resp).unwrap();
assert_eq!(cc.stale_if_error().unwrap().unwrap(), 3600);
assert_eq!(cc.serve_stale_if_error_sec().unwrap(), 3600);
assert!(cc.serve_stale_while_revalidate_sec().is_none());
}
// must-revalidate zeroes both serve-stale windows even though the raw values parse.
#[test]
fn test_must_revalidate() {
let resp = build_response(
CACHE_CONTROL,
"max-age=12345, stale-while-revalidate=60, stale-if-error=30, must-revalidate",
);
let cc = CacheControl::from_resp_headers(&resp).unwrap();
assert!(cc.must_revalidate());
assert_eq!(cc.stale_while_revalidate().unwrap().unwrap(), 60);
assert_eq!(cc.stale_if_error().unwrap().unwrap(), 30);
assert_eq!(cc.serve_stale_while_revalidate_sec().unwrap(), 0);
assert_eq!(cc.serve_stale_if_error_sec().unwrap(), 0);
}
// proxy-revalidate zeroes both serve-stale windows as well.
#[test]
fn test_proxy_revalidate() {
let resp = build_response(
CACHE_CONTROL,
"max-age=12345, stale-while-revalidate=60, stale-if-error=30, proxy-revalidate",
);
let cc = CacheControl::from_resp_headers(&resp).unwrap();
assert!(cc.proxy_revalidate());
assert_eq!(cc.stale_while_revalidate().unwrap().unwrap(), 60);
assert_eq!(cc.stale_if_error().unwrap().unwrap(), 30);
assert_eq!(cc.serve_stale_while_revalidate_sec().unwrap(), 0);
assert_eq!(cc.serve_stale_if_error_sec().unwrap(), 0);
}
// The presence of s-maxage zeroes both serve-stale windows.
#[test]
fn test_s_maxage_stale() {
let resp = build_response(
CACHE_CONTROL,
"s-maxage=0, stale-while-revalidate=60, stale-if-error=30",
);
let cc = CacheControl::from_resp_headers(&resp).unwrap();
assert_eq!(cc.stale_while_revalidate().unwrap().unwrap(), 60);
assert_eq!(cc.stale_if_error().unwrap().unwrap(), 30);
assert_eq!(cc.serve_stale_while_revalidate_sec().unwrap(), 0);
assert_eq!(cc.serve_stale_if_error_sec().unwrap(), 0);
}
// Only s-maxage, public or must-revalidate permit caching responses to authorized requests.
#[test]
fn test_authorized_request() {
let resp = build_response(CACHE_CONTROL, "max-age=10");
let cc = CacheControl::from_resp_headers(&resp).unwrap();
assert!(!cc.allow_caching_authorized_req());
let resp = build_response(CACHE_CONTROL, "s-maxage=10");
let cc = CacheControl::from_resp_headers(&resp).unwrap();
assert!(cc.allow_caching_authorized_req());
let resp = build_response(CACHE_CONTROL, "public");
let cc = CacheControl::from_resp_headers(&resp).unwrap();
assert!(cc.allow_caching_authorized_req());
let resp = build_response(CACHE_CONTROL, "must-revalidate, max-age=0");
let cc = CacheControl::from_resp_headers(&resp).unwrap();
assert!(cc.allow_caching_authorized_req());
let resp = build_response(CACHE_CONTROL, "");
let cc = CacheControl::from_resp_headers(&resp).unwrap();
assert!(!cc.allow_caching_authorized_req());
}
// Builds request parts carrying a single header `cc_key: cc_value`.
fn build_request(cc_key: HeaderName, cc_value: &str) -> request::Parts {
let (parts, _) = request::Builder::new()
.header(cc_key, cc_value)
.body(())
.unwrap()
.into_parts();
parts
}
// only-if-cached is detected on requests even when it carries a value.
#[test]
fn test_request_only_if_cached() {
let req = build_request(CACHE_CONTROL, "only-if-cached=1");
let cc = CacheControl::from_req_headers(&req).unwrap();
assert!(cc.only_if_cached())
}
}