use crate::{
error::{self, Error},
token::Token,
};
mod jwt;
use jwt::{Algorithm, Header, Key};
pub mod prelude {
pub use super::{RequestReason, ServiceAccountAccess, ServiceAccountInfo, TokenOrRequest};
}
/// OAuth2 `grant_type` used when exchanging a signed JWT assertion for an
/// access token (the JWT bearer grant defined in RFC 7523).
const GRANT_TYPE: &str = "urn:ietf:params:oauth:grant-type:jwt-bearer";
#[derive(serde::Deserialize, Debug, Clone)]
pub struct ServiceAccountInfo {
pub private_key: String,
pub client_email: String,
pub token_uri: String,
}
impl ServiceAccountInfo {
pub fn deserialize<T>(key_data: T) -> Result<Self, Error>
where
T: AsRef<[u8]>,
{
let slice = key_data.as_ref();
let account_info: Self = serde_json::from_slice(slice)?;
Ok(account_info)
}
}
/// One cache slot: a token together with the hash of the scopes it was
/// issued for.
struct Entry {
    /// Hash of the space-joined scope list; the cache Vec is kept sorted on it.
    hash: u64,
    /// The most recent token obtained for those scopes.
    token: Token,
}
/// Why [`ServiceAccountAccess::get_token`] produced a token request instead
/// of returning a cached token.
///
/// This is a plain fieldless enum, so it is cheap to copy and compare.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RequestReason {
    /// A token is cached for these scopes, but it has expired.
    Expired,
    /// No token is cached yet for this exact set of scopes.
    ScopesChanged,
}
/// Outcome of [`ServiceAccountAccess::get_token`]: either a usable token or
/// an HTTP request that has to be sent to obtain one.
#[derive(Debug)]
pub enum TokenOrRequest {
    /// A cached token that has not expired.
    Token(Token),
    /// A token request that the caller must execute.
    Request {
        /// `POST` to the token endpoint carrying the signed JWT assertion.
        request: http::Request<Vec<u8>>,
        /// Why a new token is required.
        reason: RequestReason,
        /// Cache key to hand back, with the response, to
        /// [`ServiceAccountAccess::parse_token_response`].
        scope_hash: u64,
    },
}
/// Builds token requests for a single service account and caches the tokens
/// returned for each distinct set of scopes.
///
/// No network I/O happens here: requests are handed to the caller as
/// `http::Request` values, and responses are fed back in.
pub struct ServiceAccountAccess {
    /// The account's credentials.
    info: ServiceAccountInfo,
    /// Private key bytes, base64-decoded once from the PEM body in `info`.
    priv_key: Vec<u8>,
    /// Cached tokens, kept sorted by scope hash so lookups can binary search.
    cache: parking_lot::Mutex<Vec<Entry>>,
}
impl ServiceAccountAccess {
pub fn new(info: ServiceAccountInfo) -> Result<Self, Error> {
let key_string = info
.private_key
.splitn(5, "-----")
.nth(2)
.ok_or(Error::InvalidKeyFormat)?;
let key_string = key_string.split_whitespace().fold(
String::with_capacity(key_string.len()),
|mut s, line| {
s.push_str(line);
s
},
);
let key_bytes = base64::decode_config(key_string.as_bytes(), base64::STANDARD)?;
Ok(Self {
info,
cache: parking_lot::Mutex::new(Vec::new()),
priv_key: key_bytes,
})
}
pub fn get_account_info(&self) -> &ServiceAccountInfo {
&self.info
}
pub fn get_token<'a, S, I>(&self, scopes: I) -> Result<TokenOrRequest, Error>
where
S: AsRef<str> + 'a,
I: IntoIterator<Item = &'a S>,
{
let (hash, scopes) = Self::serialize_scopes(scopes.into_iter());
let reason = {
let cache = self.cache.lock();
match cache.binary_search_by(|i| i.hash.cmp(&hash)) {
Ok(i) => {
let token = &cache[i].token;
if !token.has_expired() {
return Ok(TokenOrRequest::Token(token.clone()));
}
RequestReason::Expired
}
Err(_) => RequestReason::ScopesChanged,
}
};
let issued = chrono::Utc::now().timestamp();
let expiry = issued + 3600 - 5;
let claims = jwt::Claims {
issuer: self.info.client_email.clone(),
scope: scopes,
audience: self.info.token_uri.clone(),
expiration: expiry,
issued_at: issued,
sub: None,
};
let assertion = jwt::encode(
&Header::new(Algorithm::RS256),
&claims,
Key::Pkcs8(&self.priv_key),
)?;
let body = url::form_urlencoded::Serializer::new(String::new())
.append_pair("grant_type", GRANT_TYPE)
.append_pair("assertion", &assertion)
.finish();
let body = Vec::from(body);
let request = http::Request::builder()
.method("POST")
.uri(&self.info.token_uri)
.header(
http::header::CONTENT_TYPE,
"application/x-www-form-urlencoded",
)
.header(http::header::CONTENT_LENGTH, body.len())
.body(body)?;
Ok(TokenOrRequest::Request {
reason,
request,
scope_hash: hash,
})
}
pub fn parse_token_response<S>(
&self,
hash: u64,
response: http::Response<S>,
) -> Result<Token, Error>
where
S: AsRef<[u8]>,
{
let (parts, body) = response.into_parts();
if !parts.status.is_success() {
let body_bytes = body.as_ref();
if parts
.headers
.get(http::header::CONTENT_TYPE)
.and_then(|ct| ct.to_str().ok())
== Some("application/json; charset=utf-8")
{
if let Ok(auth_error) = serde_json::from_slice::<error::AuthError>(body_bytes) {
return Err(Error::AuthError(auth_error));
}
}
return Err(Error::HttpStatus(parts.status));
}
let token_res: TokenResponse = serde_json::from_slice(body.as_ref())?;
let token: Token = token_res.into();
{
let mut cache = self.cache.lock();
match cache.binary_search_by(|i| i.hash.cmp(&hash)) {
Ok(i) => cache[i].token = token.clone(),
Err(i) => {
cache.insert(
i,
Entry {
hash,
token: token.clone(),
},
);
}
};
}
Ok(token)
}
fn serialize_scopes<'a, I, S>(scopes: I) -> (u64, String)
where
S: AsRef<str> + 'a,
I: Iterator<Item = &'a S>,
{
use std::hash::Hasher;
let scopes = scopes.map(|s| s.as_ref()).collect::<Vec<&str>>().join(" ");
let hash = {
let mut hasher = twox_hash::XxHash::default();
hasher.write(scopes.as_bytes());
hasher.finish()
};
(hash, scopes)
}
}
/// JSON body of a successful reply from the token endpoint.
#[derive(serde::Deserialize, Debug)]
struct TokenResponse {
    /// The issued access token.
    access_token: String,
    /// The token type reported by the server.
    token_type: String,
    /// Token lifetime in seconds, counted from when the response is handled.
    expires_in: i64,
}
/// Converts a token endpoint response into a [`Token`].
///
/// Implemented as `From` rather than `Into`; the standard blanket impl still
/// provides `TokenResponse: Into<Token>`.
impl From<TokenResponse> for Token {
    /// Builds a [`Token`] whose absolute expiry is "now + `expires_in`" seconds.
    fn from(res: TokenResponse) -> Self {
        // `expires_in` is server-supplied. Saturate instead of overflowing,
        // which would panic in debug builds, on absurdly large values.
        let expires_ts = chrono::Utc::now()
            .timestamp()
            .saturating_add(res.expires_in);

        Self {
            access_token: res.access_token,
            token_type: res.token_type,
            // This response type carries no refresh token, so it is left empty.
            refresh_token: String::new(),
            expires_in: Some(res.expires_in),
            expires_in_timestamp: Some(expires_ts),
        }
    }
}
#[cfg(test)]
mod test {
    use super::*;

    /// `serialize_scopes` joins scopes with single spaces and hashes the
    /// joined string. It gives the same result for borrowed `&str` and owned
    /// `String` scope lists.
    #[test]
    fn hash_scopes() {
        use std::hash::Hasher;

        // Feed the joined string to the hasher piecewise. The streamed hash
        // must equal the single-shot hash computed by `serialize_scopes`.
        let pieces: [&[u8]; 3] = [b"scope1 ", b"scope2 ", b"scope3"];
        let mut hasher = twox_hash::XxHash::default();
        for piece in &pieces {
            hasher.write(piece);
        }
        let expected = hasher.finish();

        let borrowed = ["scope1", "scope2", "scope3"];
        let (borrowed_hash, borrowed_joined) =
            ServiceAccountAccess::serialize_scopes(borrowed.iter());
        assert_eq!(expected, borrowed_hash);
        assert_eq!("scope1 scope2 scope3", borrowed_joined);

        let owned: Vec<String> = borrowed.iter().map(|s| s.to_string()).collect();
        let (owned_hash, owned_joined) = ServiceAccountAccess::serialize_scopes(owned.iter());
        assert_eq!(expected, owned_hash);
        assert_eq!("scope1 scope2 scope3", owned_joined);
    }
}