pub use crate::types::TokenInfo;
use futures::lock::Mutex;
use std::collections::HashMap;
use std::io;
use std::path::{Path, PathBuf};
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
/// 64-bit fingerprint of a list of scopes.
///
/// The value is produced by `ScopeSet::from` and is used as the key of
/// `JSONTokens::token_map`, so two scope lists with the same fingerprint
/// share one cache slot.
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
struct ScopeHash(u64);
/// 64-bit mask summarising a list of scopes.
///
/// Each scope contributes a handful of bits (see `ScopeSet::from`). Comparing
/// two masks is a cheap pre-check that can rule out containment before any
/// string comparison is done.
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
struct ScopeFilter(u64);

/// Answer given by [`ScopeFilter::is_subset_of`].
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
enum FilterResponse {
    /// Every bit of the tested mask is also set in the other mask, so the
    /// scopes may be contained; the caller still has to compare the strings.
    Maybe,
    /// At least one bit of the tested mask is absent from the other mask.
    No,
}

impl ScopeFilter {
    /// Checks whether all bits set in `self` are also set in `filter`.
    ///
    /// Returns [`FilterResponse::Maybe`] when they are and
    /// [`FilterResponse::No`] otherwise.
    fn is_subset_of(self, filter: ScopeFilter) -> FilterResponse {
        // Bits that `self` has but `filter` lacks.
        let missing_bits = self.0 & !filter.0;
        match missing_bits {
            0 => FilterResponse::Maybe,
            _ => FilterResponse::No,
        }
    }
}
/// A borrowed list of scopes together with its precomputed fingerprint and
/// filter mask.
///
/// The `hash` and `filter` fields are filled in by `ScopeSet::from`.
#[derive(Debug)]
pub(crate) struct ScopeSet<'a, T> {
/// Fingerprint of `scopes`; used as the lookup key in `JSONTokens`.
hash: ScopeHash,
/// Bit mask of `scopes`; used to pre-screen superset candidates.
filter: ScopeFilter,
/// The scopes as supplied by the caller (not sorted or deduplicated here).
scopes: &'a [T],
}
// `Clone` and `Copy` are written by hand with no bound on `T`: the struct only
// holds a shared slice reference plus two `Copy` integers.
impl<'a, T> Clone for ScopeSet<'a, T> {
fn clone(&self) -> Self {
// Relies on the `Copy` impl below.
*self
}
}
impl<'a, T> Copy for ScopeSet<'a, T> {}
impl<'a, T> ScopeSet<'a, T>
where
    T: AsRef<str>,
{
    /// Builds a `ScopeSet` over `scopes`, computing its fingerprint and
    /// filter mask.
    ///
    /// * `hash` is the XOR of the seahash of every *distinct* scope, which
    ///   makes it independent of the order of `scopes`. Repeated scopes are
    ///   folded in only once: XOR-ing the same value twice would cancel it
    ///   out, so `["a", "a", "b"]` would otherwise get the fingerprint of
    ///   `["b"]` and `["a", "a"]` the fingerprint of the empty list.
    /// * `filter` has, for every scope, the bits selected by the four lowest
    ///   6-bit chunks of that scope's seahash set.
    ///
    /// An empty slice yields a zero hash and a zero filter.
    pub fn from(scopes: &'a [T]) -> Self {
        let (hash, filter) = scopes.iter().enumerate().fold(
            (ScopeHash(0), ScopeFilter(0)),
            |(mut scope_hash, mut scope_filter), (index, scope)| {
                let name: &str = scope.as_ref();
                let h = seahash::hash(name.as_bytes());
                for i in 0..4 {
                    // `bit` is always in 0..64, so the shift below cannot overflow.
                    let bit = (h >> (6 * i)) & 0b11_1111;
                    scope_filter.0 |= 1 << bit;
                }
                // Only the first occurrence of a scope contributes to the
                // fingerprint; scope lists are short, so a linear look-back
                // is cheap and needs no allocation.
                let seen_before = scopes[..index]
                    .iter()
                    .any(|earlier| name == earlier.as_ref());
                if !seen_before {
                    scope_hash.0 ^= h;
                }
                (scope_hash, scope_filter)
            },
        );
        ScopeSet {
            hash,
            filter,
            scopes,
        }
    }
}
/// Interface for a caller-provided token store; it is used through the
/// `Storage::Custom` variant.
///
/// Implementations must be `Send + Sync`.
#[async_trait]
pub trait TokenStorage: Send + Sync {
/// Stores `token` for the given `scopes`.
///
/// When called via `Storage::set`, `scopes` has been sorted and
/// deduplicated beforehand.
async fn set(&self, scopes: &[&str], token: TokenInfo) -> anyhow::Result<()>;
/// Looks up a token for the given `scopes`.
///
/// When called via `Storage::get`, `scopes` has been sorted and
/// deduplicated beforehand. `None` presumably means that no usable token
/// is stored; the exact contract is up to the implementation.
async fn get(&self, scopes: &[&str]) -> Option<TokenInfo>;
}
/// The token store back ends supported by this module.
pub(crate) enum Storage {
/// Tokens are kept only in a mutex-protected in-memory map.
Memory { tokens: Mutex<JSONTokens> },
/// Tokens are kept in memory and written to a JSON file on every `set`.
Disk(DiskStorage),
/// Tokens are delegated to a user-supplied `TokenStorage` implementation.
Custom(Box<dyn TokenStorage>),
}
impl Storage {
pub(crate) async fn set<T>(
&self,
scopes: ScopeSet<'_, T>,
token: TokenInfo,
) -> anyhow::Result<()>
where
T: AsRef<str>,
{
match self {
Storage::Memory { tokens } => Ok(tokens.lock().await.set(scopes, token)?),
Storage::Disk(disk_storage) => Ok(disk_storage.set(scopes, token).await?),
Storage::Custom(custom_storage) => {
let mut str_scopes = scopes
.scopes
.iter()
.map(|scope| scope.as_ref())
.collect::<Vec<_>>();
str_scopes.sort_unstable();
str_scopes.dedup();
custom_storage.set(&str_scopes[..], token).await
}
}
}
pub(crate) async fn get<T>(&self, scopes: ScopeSet<'_, T>) -> Option<TokenInfo>
where
T: AsRef<str>,
{
match self {
Storage::Memory { tokens } => tokens.lock().await.get(scopes),
Storage::Disk(disk_storage) => disk_storage.get(scopes).await,
Storage::Custom(custom_storage) => {
let mut str_scopes = scopes
.scopes
.iter()
.map(|scope| scope.as_ref())
.collect::<Vec<_>>();
str_scopes.sort_unstable();
str_scopes.dedup();
custom_storage.get(&str_scopes[..]).await
}
}
}
}
/// One cached token together with the scopes it was stored for.
///
/// Only `scopes` and `token` are serialized (see the manual `Serialize` /
/// `Deserialize` impls); `hash` and `filter` are recomputed from `scopes`
/// when a value is deserialized.
#[derive(Debug, Clone)]
struct JSONToken {
/// Scopes the token was stored under.
scopes: Vec<String>,
/// The token itself.
token: TokenInfo,
/// Fingerprint of `scopes`.
hash: ScopeHash,
/// Filter mask of `scopes`.
filter: ScopeFilter,
}
impl<'de> Deserialize<'de> for JSONToken {
    /// Reads the persisted `scopes`/`token` pair and rebuilds the derived
    /// `hash` and `filter` fields from the scopes.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        // On-disk shape of a token: just the scopes and the token.
        #[derive(Deserialize)]
        struct RawJSONToken {
            scopes: Vec<String>,
            token: TokenInfo,
        }

        let raw = RawJSONToken::deserialize(deserializer)?;
        // The fingerprint and filter are never stored; derive them again.
        let derived = ScopeSet::from(&raw.scopes);
        let (hash, filter) = (derived.hash, derived.filter);
        Ok(JSONToken {
            scopes: raw.scopes,
            token: raw.token,
            hash,
            filter,
        })
    }
}
impl Serialize for JSONToken {
    /// Writes only the `scopes` and `token` fields; the derived `hash` and
    /// `filter` are left out.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        // Borrowing view with the persisted fields only.
        #[derive(Serialize)]
        struct RawJSONToken<'a> {
            scopes: &'a [String],
            token: &'a TokenInfo,
        }

        let raw = RawJSONToken {
            scopes: &self.scopes,
            token: &self.token,
        };
        raw.serialize(serializer)
    }
}
/// Collection of cached tokens, keyed by the fingerprint of their scopes.
///
/// It is serialized as a plain sequence of `JSONToken` values; the map keys
/// are rebuilt from each token's `hash` on deserialization.
#[derive(Debug, Clone)]
pub(crate) struct JSONTokens {
/// Stored tokens, keyed by `JSONToken::hash`.
token_map: HashMap<ScopeHash, JSONToken>,
}
impl Serialize for JSONTokens {
    /// Serializes the stored tokens as a sequence; the map keys are not
    /// written.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        let stored_tokens = self.token_map.values();
        serializer.collect_seq(stored_tokens)
    }
}
impl<'de> Deserialize<'de> for JSONTokens {
    /// Reads a sequence of `JSONToken` values and indexes them by their
    /// fingerprint. If two elements share a fingerprint, the later one wins.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        /// Visitor that collects the sequence elements into the token map.
        struct TokenSeqVisitor;

        impl<'de> serde::de::Visitor<'de> for TokenSeqVisitor {
            type Value = JSONTokens;

            fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
                formatter.write_str("a sequence of JSONToken's")
            }

            fn visit_seq<M>(self, mut access: M) -> Result<Self::Value, M::Error>
            where
                M: serde::de::SeqAccess<'de>,
            {
                // Pre-size the map when the format reports a length.
                let capacity = access.size_hint().unwrap_or(0);
                let mut token_map = HashMap::with_capacity(capacity);
                loop {
                    match access.next_element::<JSONToken>()? {
                        Some(entry) => {
                            token_map.insert(entry.hash, entry);
                        }
                        None => break,
                    }
                }
                Ok(JSONTokens { token_map })
            }
        }

        deserializer.deserialize_seq(TokenSeqVisitor)
    }
}
impl JSONTokens {
pub(crate) fn new() -> Self {
JSONTokens {
token_map: HashMap::new(),
}
}
async fn load_from_file(filename: &Path) -> Result<Self, io::Error> {
let contents = tokio::fs::read(filename).await?;
serde_json::from_slice(&contents).map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))
}
fn get<T>(
&self,
ScopeSet {
hash,
filter,
scopes,
}: ScopeSet<T>,
) -> Option<TokenInfo>
where
T: AsRef<str>,
{
if let Some(json_token) = self.token_map.get(&hash) {
return Some(json_token.token.clone());
}
let requested_scopes_are_subset_of = |other_scopes: &[String]| {
scopes
.iter()
.all(|s| other_scopes.iter().any(|t| t.as_str() == s.as_ref()))
};
self.token_map
.values()
.filter(|json_token| filter.is_subset_of(json_token.filter) == FilterResponse::Maybe)
.find(|v: &&JSONToken| requested_scopes_are_subset_of(&v.scopes))
.map(|t: &JSONToken| t.token.clone())
}
fn set<T>(
&mut self,
ScopeSet {
hash,
filter,
scopes,
}: ScopeSet<T>,
token: TokenInfo,
) -> Result<(), io::Error>
where
T: AsRef<str>,
{
use std::collections::hash_map::Entry;
match self.token_map.entry(hash) {
Entry::Occupied(mut entry) => {
entry.get_mut().token = token;
}
Entry::Vacant(entry) => {
let json_token = JSONToken {
scopes: scopes.iter().map(|x| x.as_ref().to_owned()).collect(),
token,
hash,
filter,
};
entry.insert(json_token);
}
}
Ok(())
}
}
/// Token store that keeps its tokens in memory and mirrors them to a JSON
/// file.
pub(crate) struct DiskStorage {
/// In-memory copy of the tokens, guarded by an async mutex.
tokens: Mutex<JSONTokens>,
/// Path of the JSON file the tokens are loaded from and written to.
filename: PathBuf,
}
impl DiskStorage {
    /// Opens the token store backed by `filename`.
    ///
    /// If the file does not exist the store starts out empty.
    ///
    /// # Errors
    /// Any other I/O error, and any JSON parse error (reported as
    /// `InvalidData`), is returned to the caller.
    pub(crate) async fn new(filename: PathBuf) -> Result<Self, io::Error> {
        let tokens = match JSONTokens::load_from_file(&filename).await {
            Ok(tokens) => tokens,
            Err(e) if e.kind() == io::ErrorKind::NotFound => JSONTokens::new(),
            Err(e) => return Err(e),
        };
        Ok(DiskStorage {
            tokens: Mutex::new(tokens),
            filename,
        })
    }

    /// Stores `token` for `scopes` and rewrites the backing file with the
    /// complete, updated token list.
    ///
    /// # Errors
    /// Returns an error if serializing the tokens fails (reported as
    /// `InvalidData`) or if opening, writing or flushing the file fails.
    pub(crate) async fn set<T>(
        &self,
        scopes: ScopeSet<'_, T>,
        token: TokenInfo,
    ) -> Result<(), io::Error>
    where
        T: AsRef<str>,
    {
        use tokio::io::AsyncWriteExt;
        // Update the in-memory map and take a JSON snapshot while holding the
        // lock; the lock is released before any file I/O starts.
        // NOTE(review): because the lock is released here, concurrent `set`
        // calls may write their snapshots to the file in either order.
        let json = {
            use std::ops::Deref;
            let mut lock = self.tokens.lock().await;
            lock.set(scopes, token)?;
            serde_json::to_string(lock.deref())
                .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?
        };
        let mut f = open_writeable_file(&self.filename).await?;
        f.write_all(json.as_bytes()).await?;
        // tokio performs file writes on a background thread, so `write_all`
        // can return before the data has reached the file. Flushing waits for
        // the write to finish and reports its error instead of dropping it
        // together with the file handle.
        f.flush().await?;
        Ok(())
    }

    /// Returns a token covering `scopes` from the in-memory copy, if any.
    /// The file is not read.
    pub(crate) async fn get<T>(&self, scopes: ScopeSet<'_, T>) -> Option<TokenInfo>
    where
        T: AsRef<str>,
    {
        self.tokens.lock().await.get(scopes)
    }
}
/// Opens `filename` for writing, creating it if needed and truncating any
/// existing contents. On Unix a newly created file gets mode `0o600`.
#[cfg(unix)]
async fn open_writeable_file(
    filename: impl AsRef<Path>,
) -> Result<tokio::fs::File, tokio::io::Error> {
    use std::os::unix::fs::OpenOptionsExt;

    // Configure the options on the std type, then convert them to tokio's.
    let mut std_opts = std::fs::OpenOptions::new();
    std_opts
        .write(true)
        .create(true)
        .truncate(true)
        .mode(0o600);
    let async_opts = tokio::fs::OpenOptions::from(std_opts);
    async_opts.open(filename).await
}
/// Opens `filename` for writing, creating it if needed and truncating any
/// existing contents (non-Unix variant; no permission bits are set).
#[cfg(not(unix))]
async fn open_writeable_file(
    filename: impl AsRef<Path>,
) -> Result<tokio::fs::File, tokio::io::Error> {
    let file = tokio::fs::File::create(filename).await?;
    Ok(file)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A single-scope filter fits inside the combined filter, but not the
    /// other way round, and unrelated scopes do not contain each other.
    #[test]
    fn test_scope_filter() {
        let foo = ScopeSet::from(&["foo"]).filter;
        let bar = ScopeSet::from(&["bar"]).filter;
        let foobar = ScopeSet::from(&["foo", "bar"]).filter;

        // Each single scope may be contained in the combined set.
        assert_eq!(foo.is_subset_of(foobar), FilterResponse::Maybe);
        assert_eq!(bar.is_subset_of(foobar), FilterResponse::Maybe);

        // Distinct single scopes are not contained in one another.
        assert_eq!(foo.is_subset_of(bar), FilterResponse::No);
        assert_eq!(bar.is_subset_of(foo), FilterResponse::No);

        // The combined set is not contained in either single scope.
        assert_eq!(foobar.is_subset_of(foo), FilterResponse::No);
        assert_eq!(foobar.is_subset_of(bar), FilterResponse::No);
    }

    /// A token written through one `DiskStorage` is visible both right away
    /// and from a second instance opened on the same file.
    #[tokio::test]
    async fn test_disk_storage() {
        let new_token = |access_token: &str| TokenInfo {
            access_token: Some(access_token.to_owned()),
            refresh_token: None,
            expires_at: None,
            id_token: None,
        };
        let scope_set = ScopeSet::from(&["myscope"]);
        let tempdir = tempfile::Builder::new()
            .prefix("yup-oauth2-tests_")
            .rand_bytes(15)
            .tempdir()
            .unwrap();
        let storage_path = tempdir.path().join("tokenstorage.json");

        // First instance: starts empty, then stores and returns the token.
        {
            let storage = DiskStorage::new(storage_path.clone()).await.unwrap();
            assert!(storage.get(scope_set).await.is_none());
            storage
                .set(scope_set, new_token("my_access_token"))
                .await
                .unwrap();
            assert_eq!(
                storage.get(scope_set).await,
                Some(new_token("my_access_token"))
            );
        }

        // Second instance: the token is loaded back from the file.
        {
            let storage = DiskStorage::new(storage_path.clone()).await.unwrap();
            assert_eq!(
                storage.get(scope_set).await,
                Some(new_token("my_access_token"))
            );
        }
    }
}