#![cfg(feature = "service-account")]
use crate::error::Error;
use crate::types::TokenInfo;
use std::{io, path::PathBuf};
use base64::Engine as _;
use http::header;
use http_body_util::BodyExt;
use hyper_util::client::legacy::connect::Connect;
#[cfg(all(feature = "aws-lc-rs", not(feature = "ring")))]
use rustls::crypto::aws_lc_rs as crypto_provider;
#[cfg(feature = "ring")]
use rustls::crypto::ring as crypto_provider;
use rustls::{self, pki_types::PrivateKeyDer, sign::SigningKey};
use serde::{Deserialize, Serialize};
use time::OffsetDateTime;
use url::form_urlencoded;
const GRANT_TYPE: &str = "urn:ietf:params:oauth:grant-type:jwt-bearer";
const GOOGLE_RS256_HEAD: &str = r#"{"alg":"RS256","typ":"JWT"}"#;
/// Appends the base64url encoding of `s` to `out`, without `=` padding.
///
/// JWS compact serialization (RFC 7515 §2 / §7.1) defines every segment as
/// base64url with all trailing `=` characters omitted, so the unpadded
/// URL-safe engine is used rather than the padded one.
fn append_base64<T: AsRef<[u8]> + ?Sized>(s: &T, out: &mut String) {
    base64::engine::general_purpose::URL_SAFE_NO_PAD.encode_string(s, out)
}
/// Extracts the first PKCS#8 private key from a PEM-encoded string.
///
/// Only the first PKCS#8 entry is considered; any further keys are ignored.
///
/// # Errors
/// Returns an `io::ErrorKind::InvalidInput` error when the PEM contains no
/// PKCS#8 key, or when the first entry cannot be read.
fn decode_rsa_key(pem_pkcs8: &str) -> Result<PrivateKeyDer, io::Error> {
    let invalid = |msg: &str| io::Error::new(io::ErrorKind::InvalidInput, msg);

    let first = rustls_pemfile::pkcs8_private_keys(&mut pem_pkcs8.as_bytes()).next();
    let Some(parsed) = first else {
        return Err(invalid("Not enough private keys in PEM"));
    };

    // The underlying read error is intentionally replaced by a fixed message.
    parsed
        .map(PrivateKeyDer::Pkcs8)
        .map_err(|_| invalid("Error reading key from PEM"))
}
/// A service account key, as deserialized from a service account key JSON
/// document.
///
/// `Debug` is implemented by hand so that `private_key` is never printed.
#[derive(Serialize, Deserialize, Clone)]
pub struct ServiceAccountKey {
    /// The JSON `type` field of the key document.
    #[serde(rename = "type")]
    pub key_type: Option<String>,
    /// Optional project identifier; not used by this module.
    pub project_id: Option<String>,
    /// Optional identifier of `private_key`; not used by this module.
    pub private_key_id: Option<String>,
    /// PEM-encoded PKCS#8 RSA private key used to sign the JWT assertion.
    pub private_key: String,
    /// Service account e-mail; used as the JWT `iss` claim.
    pub client_email: String,
    /// Optional client identifier; not used by this module.
    pub client_id: Option<String>,
    /// Optional authorization URI; not used by this module.
    pub auth_uri: Option<String>,
    /// Token endpoint: used as the JWT `aud` claim and as the POST target.
    pub token_uri: String,
    /// Optional provider certificate URL; not used by this module.
    pub auth_provider_x509_cert_url: Option<String>,
    /// Optional client certificate URL; not used by this module.
    pub client_x509_cert_url: Option<String>,
}

impl std::fmt::Debug for ServiceAccountKey {
    /// Formats every field except `private_key`, which is replaced by a
    /// placeholder so the secret cannot leak through `{:?}` logging or panics.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ServiceAccountKey")
            .field("key_type", &self.key_type)
            .field("project_id", &self.project_id)
            .field("private_key_id", &self.private_key_id)
            .field("private_key", &"<redacted>")
            .field("client_email", &self.client_email)
            .field("client_id", &self.client_id)
            .field("auth_uri", &self.auth_uri)
            .field("token_uri", &self.token_uri)
            .field(
                "auth_provider_x509_cert_url",
                &self.auth_provider_x509_cert_url,
            )
            .field("client_x509_cert_url", &self.client_x509_cert_url)
            .finish()
    }
}
/// JWT claim set carried in the JWT-bearer assertion sent to the token endpoint.
#[derive(Serialize, Debug)]
struct Claims<'a> {
    /// Issuer: the service account's `client_email`.
    iss: &'a str,
    /// Audience: the key's `token_uri`.
    aud: &'a str,
    /// Expiration time, in Unix seconds.
    exp: i64,
    /// Issued-at time, in Unix seconds.
    iat: i64,
    /// Optional `sub` claim. It is omitted from the JSON when `None` rather
    /// than emitted as `"sub": null`, since `null` is not a valid claim value.
    #[serde(rename = "sub", skip_serializing_if = "Option::is_none")]
    subject: Option<&'a str>,
    /// Requested scopes joined with single spaces.
    scope: String,
}

impl<'a> Claims<'a> {
    /// Lifetime of an assertion: one hour minus 5 seconds.
    /// NOTE(review): the 5 s margin presumably keeps `exp - iat` safely under a
    /// one-hour maximum enforced by the server — confirm.
    const LIFETIME_SECS: i64 = 3600 - 5;

    /// Builds the claim set for `key`, valid from now for [`Self::LIFETIME_SECS`].
    ///
    /// `scopes` are joined with a single space; `subject` becomes the `sub` claim.
    fn new<T>(key: &'a ServiceAccountKey, scopes: &[T], subject: Option<&'a str>) -> Self
    where
        T: AsRef<str>,
    {
        let iat = OffsetDateTime::now_utc().unix_timestamp();
        Claims {
            iss: &key.client_email,
            aud: &key.token_uri,
            exp: iat + Self::LIFETIME_SECS,
            iat,
            subject,
            scope: crate::helper::join(scopes, " "),
        }
    }
}
/// Signs JWT claim sets with RS256 using a service account's RSA private key.
pub(crate) struct JWTSigner {
    /// rustls signer bound to the `RSA_PKCS1_SHA256` scheme.
    signer: Box<dyn rustls::sign::Signer>,
}

impl JWTSigner {
    /// Builds a signer from a PEM-encoded PKCS#8 RSA private key.
    ///
    /// # Errors
    /// Returns an `io::ErrorKind::InvalidInput` error if no key can be read
    /// from the PEM. Returns an `io::ErrorKind::Other` error if the crypto
    /// provider rejects the key or cannot offer the `RSA_PKCS1_SHA256` scheme.
    fn new(private_key: &str) -> Result<Self, io::Error> {
        let key = decode_rsa_key(private_key)?;
        let signing_key = crypto_provider::sign::RsaSigningKey::new(&key)
            .map_err(|_| io::Error::new(io::ErrorKind::Other, "Couldn't initialize signer"))?;
        // RS256 (the `alg` in the JWT header) is RSASSA-PKCS1-v1_5 with SHA-256.
        let signer = signing_key
            .choose_scheme(&[rustls::SignatureScheme::RSA_PKCS1_SHA256])
            .ok_or_else(|| {
                io::Error::new(io::ErrorKind::Other, "Couldn't choose signing scheme")
            })?;
        Ok(JWTSigner { signer })
    }

    /// Produces the compact JWT `header.claims.signature` for `claims`.
    ///
    /// # Errors
    /// Propagates any error returned by the underlying rustls signer.
    fn sign_claims(&self, claims: &Claims) -> Result<String, rustls::Error> {
        let mut jwt = Self::encode_claims(claims);
        // The signature covers the encoded `header.claims` signing input.
        let signature = self.signer.sign(jwt.as_bytes())?;
        jwt.push('.');
        append_base64(&signature, &mut jwt);
        Ok(jwt)
    }

    /// Returns the base64-encoded `header.claims` signing input for `claims`.
    fn encode_claims(claims: &Claims) -> String {
        let payload = serde_json::to_string(claims)
            .expect("Claims holds only string and integer fields, so serializing it cannot fail");
        let mut head = String::new();
        append_base64(GOOGLE_RS256_HEAD, &mut head);
        head.push('.');
        append_base64(&payload, &mut head);
        head
    }
}
/// Options used to construct a [`ServiceAccountFlow`].
pub struct ServiceAccountFlowOpts {
    /// Where the service account key comes from.
    pub(crate) key: FlowOptsKey,
    /// Optional subject, sent as the JWT `sub` claim of every assertion.
    pub(crate) subject: Option<String>,
}
/// Source of the service account key for [`ServiceAccountFlowOpts`].
pub(crate) enum FlowOptsKey {
    /// Path to a key file, loaded with `crate::read_service_account_key`.
    Path(PathBuf),
    /// An already-parsed key, used as-is.
    Key(Box<ServiceAccountKey>),
}
/// Obtains access tokens for a service account via the JWT-bearer grant.
pub struct ServiceAccountFlow {
    /// Key providing the issuer, audience and token endpoint.
    key: ServiceAccountKey,
    /// Optional `sub` claim included in every assertion.
    subject: Option<String>,
    /// RS256 signer built from `key.private_key`.
    signer: JWTSigner,
}

impl ServiceAccountFlow {
    /// Builds the flow, reading the key from disk when `opts` names a path.
    ///
    /// The private key is turned into a signer up front, so an unusable key
    /// fails here rather than on the first token request.
    ///
    /// # Errors
    /// Returns an `io::Error` if the key file cannot be read, or if the private
    /// key cannot be turned into a signer.
    pub(crate) async fn new(opts: ServiceAccountFlowOpts) -> Result<Self, io::Error> {
        let key = match opts.key {
            FlowOptsKey::Path(path) => crate::read_service_account_key(path).await?,
            FlowOptsKey::Key(key) => *key,
        };
        let signer = JWTSigner::new(&key.private_key)?;
        Ok(ServiceAccountFlow {
            key,
            subject: opts.subject,
            signer,
        })
    }

    /// Requests an access token for `scopes` by POSTing a signed JWT assertion
    /// to the key's `token_uri`.
    ///
    /// # Errors
    /// Returns `Error::LowLevelError` if the claims cannot be signed or if
    /// `token_uri` cannot be used as a request URI. Other errors come from the
    /// HTTP request, reading the response body, or `TokenInfo::from_json`.
    pub(crate) async fn token<C, T>(
        &self,
        hyper_client: &hyper_util::client::legacy::Client<C, String>,
        scopes: &[T],
    ) -> Result<TokenInfo, Error>
    where
        T: AsRef<str>,
        C: Connect + Clone + Send + Sync + 'static,
    {
        let claims = Claims::new(&self.key, scopes, self.subject.as_deref());
        let signed = self.signer.sign_claims(&claims).map_err(|_| {
            Error::LowLevelError(io::Error::new(
                io::ErrorKind::Other,
                "unable to sign claims",
            ))
        })?;
        let rqbody = form_urlencoded::Serializer::new(String::new())
            .extend_pairs(&[("grant_type", GRANT_TYPE), ("assertion", signed.as_str())])
            .finish();
        // `token_uri` comes from the service account key, which is external
        // input, so a malformed URI must be reported as an error, not a panic.
        let request = http::Request::post(&self.key.token_uri)
            .header(header::CONTENT_TYPE, "application/x-www-form-urlencoded")
            .body(rqbody)
            .map_err(|e| {
                Error::LowLevelError(io::Error::new(
                    io::ErrorKind::InvalidInput,
                    format!("invalid token_uri {:?}: {}", self.key.token_uri, e),
                ))
            })?;
        // NOTE(review): at debug level this logs the request body, which
        // contains the signed assertion (a bearer credential).
        log::debug!("requesting token from service account: {:?}", request);
        let (head, body) = hyper_client.request(request).await?.into_parts();
        let body = body.collect().await?.to_bytes();
        log::debug!("received response; head: {:?}, body: {:?}", head, body);
        TokenInfo::from_json(&body)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::helper::read_service_account_key;

    const TEST_PRIVATE_KEY_PATH: &str = "examples/Sanguine-69411a0c0eea.json";

    /// Loads the example service account key bundled with the repository.
    async fn load_test_key() -> ServiceAccountKey {
        read_service_account_key(TEST_PRIVATE_KEY_PATH)
            .await
            .unwrap()
    }

    /// Manual smoke test against the real token endpoint.
    ///
    /// It carries no test attribute, so it never runs automatically; that is
    /// why `dead_code` is allowed.
    #[cfg(feature = "hyper-rustls")]
    #[allow(dead_code)]
    async fn test_service_account_e2e() {
        let flow = ServiceAccountFlow::new(ServiceAccountFlowOpts {
            key: FlowOptsKey::Path(TEST_PRIVATE_KEY_PATH.into()),
            subject: None,
        })
        .await
        .unwrap();
        let connector = hyper_rustls::HttpsConnectorBuilder::new()
            .with_provider_and_native_roots(crypto_provider::default_provider())
            .unwrap()
            .https_only()
            .enable_http1()
            .enable_http2()
            .build();
        let client =
            hyper_util::client::legacy::Client::builder(hyper_util::rt::TokioExecutor::new())
                .build(connector);
        for scope in [
            "https://www.googleapis.com/auth/pubsub",
            "https://some.scope/likely-to-hand-out-id-tokens",
        ] {
            println!("{:?}", flow.token(&client, &[scope]).await);
        }
    }

    #[tokio::test]
    async fn test_jwt_initialize_claims() {
        let key = load_test_key().await;
        let claims = Claims::new(&key, &["scope1", "scope2", "scope3"], None);

        assert_eq!(
            claims.iss,
            "oauth2-public-test@sanguine-rhythm-105020.iam.gserviceaccount.com"
        );
        assert_eq!(claims.scope, "scope1 scope2 scope3");
        assert_eq!(claims.aud, "https://accounts.google.com/o/oauth2/token");
        assert!(claims.exp > 1_000_000_000);
        assert!(claims.iat < claims.exp);
        assert_eq!(claims.exp - claims.iat, 3595);
    }

    #[tokio::test]
    async fn test_jwt_sign() {
        let key = load_test_key().await;
        let signer = JWTSigner::new(&key.private_key).unwrap();
        let claims = Claims::new(&key, &["scope1", "scope2", "scope3"], None);

        let jwt = signer
            .sign_claims(&claims)
            .expect("signing the test claims should succeed");

        // The first segment is the base64 encoding of `{"alg":"RS256","typ":"JWT"}`.
        let header_segment = jwt.split('.').next().unwrap();
        assert_eq!(header_segment, "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9");
    }
}