1use std::collections::HashMap;
2
3use super::TEN_SEC;
4use chrono::{TimeZone, Utc};
5use form_urlencoded::Serializer;
6use http::{
7 header::{HeaderValue, AUTHORIZATION, CONTENT_TYPE},
8 Method, Request, Uri, Version,
9};
10use http_body_util::BodyExt;
11use hyper_util::{
12 client::legacy::{connect::HttpConnector, Client},
13 rt::TokioExecutor,
14};
15use secrecy::{ExposeSecret, SecretString};
16use serde::{Deserialize, Deserializer};
17use serde_json::Number;
18
19pub mod errors {
21 use super::Oidc;
22 use http::{uri::InvalidUri, StatusCode};
23 use thiserror::Error;
24
25 #[derive(Error, Debug)]
27 pub enum IdTokenError {
28 #[error("not a valid JWT token")]
30 InvalidFormat,
31 #[error("failed to decode base64: {0}")]
33 InvalidBase64(
34 #[source]
35 #[from]
36 base64::DecodeError,
37 ),
38 #[error("failed to unmarshal JSON: {0}")]
40 InvalidJson(
41 #[source]
42 #[from]
43 serde_json::Error,
44 ),
45 #[error("invalid expiration timestamp")]
47 InvalidExpirationTimestamp,
48 }
49
50 #[derive(Error, Debug, Clone)]
52 pub enum RefreshInitError {
53 #[error("missing field {0}")]
55 MissingField(&'static str),
56 #[cfg(feature = "openssl-tls")]
58 #[cfg_attr(docsrs, doc(cfg(feature = "openssl-tls")))]
59 #[error("failed to create OpenSSL HTTPS connector: {0}")]
60 CreateOpensslHttpsConnector(
61 #[source]
62 #[from]
63 openssl::error::ErrorStack,
64 ),
65 #[error("No valid native root CA certificates found")]
67 NoValidNativeRootCA,
68 }
69
70 #[derive(Error, Debug)]
72 pub enum RefreshError {
73 #[error("invalid URI: {0}")]
75 InvalidURI(
76 #[source]
77 #[from]
78 InvalidUri,
79 ),
80 #[error("hyper error: {0}")]
82 HyperError(
83 #[source]
84 #[from]
85 hyper::Error,
86 ),
87 #[error("hyper-util error: {0}")]
89 HyperUtilError(
90 #[source]
91 #[from]
92 hyper_util::client::legacy::Error,
93 ),
94 #[error("invalid metadata received from the provider: {0}")]
96 InvalidMetadata(#[source] serde_json::Error),
97 #[error("request failed with status code: {0}")]
99 RequestFailed(StatusCode),
100 #[error("http error: {0}")]
102 HttpError(
103 #[source]
104 #[from]
105 http::Error,
106 ),
107 #[error("failed to authorize with the provider using any of known authorization styles")]
109 AuthorizationFailure,
110 #[error("invalid token response received from the provider: {0}")]
112 InvalidTokenResponse(#[source] serde_json::Error),
113 #[error("no ID token received from the provider")]
115 NoIdTokenReceived,
116 }
117
118 #[derive(Error, Debug)]
120 pub enum Error {
121 #[error("missing field {}", Oidc::CONFIG_ID_TOKEN)]
123 IdTokenMissing,
124 #[error("invalid ID token: {0}")]
126 IdToken(
127 #[source]
128 #[from]
129 IdTokenError,
130 ),
131 #[error("ID token expired and refreshing is not possible: {0}")]
133 RefreshInit(
134 #[source]
135 #[from]
136 RefreshInitError,
137 ),
138 #[error("ID token expired and refreshing failed: {0}")]
140 Refresh(
141 #[source]
142 #[from]
143 RefreshError,
144 ),
145 }
146}
147
148use base64::Engine as _;
149const JWT_BASE64_ENGINE: base64::engine::GeneralPurpose = base64::engine::GeneralPurpose::new(
150 &base64::alphabet::URL_SAFE,
151 base64::engine::GeneralPurposeConfig::new()
152 .with_decode_allow_trailing_bits(true)
153 .with_decode_padding_mode(base64::engine::DecodePaddingMode::Indifferent),
154);
155use base64::engine::general_purpose::STANDARD as STANDARD_BASE64_ENGINE;
156
/// ID-token source built from an auth-provider configuration map
/// (keys such as `id-token`).
///
/// Keeps the current ID token and, when the configuration allowed it, a refresher
/// able to obtain a new one. If the refresher could not be built, the
/// initialization error is stored instead. It is reported only when a refresh is
/// actually needed.
#[derive(Debug)]
pub struct Oidc {
    // Current ID token (a JWT); wrapped in `SecretString` so it is not exposed by `Debug`.
    id_token: SecretString,
    // Refresh machinery, or the reason it is unavailable.
    refresher: Result<Refresher, errors::RefreshInitError>,
}
162
163impl Oidc {
164 const CONFIG_ID_TOKEN: &'static str = "id-token";
166
167 fn token_valid(&self) -> Result<bool, errors::IdTokenError> {
169 let part = self
170 .id_token
171 .expose_secret()
172 .split('.')
173 .nth(1)
174 .ok_or(errors::IdTokenError::InvalidFormat)?;
175 let payload = JWT_BASE64_ENGINE.decode(part)?;
176 let expiry = serde_json::from_slice::<Claims>(&payload)?.expiry;
177 let timestamp = Utc
178 .timestamp_opt(expiry, 0)
179 .earliest()
180 .ok_or(errors::IdTokenError::InvalidExpirationTimestamp)?;
181
182 let valid = Utc::now() + TEN_SEC < timestamp;
183
184 Ok(valid)
185 }
186
187 pub async fn id_token(&mut self) -> Result<String, errors::Error> {
189 if self.token_valid()? {
190 return Ok(self.id_token.expose_secret().to_string());
191 }
192
193 let id_token = self.refresher.as_mut().map_err(|e| e.clone())?.id_token().await?;
194
195 self.id_token = id_token.clone().into();
196
197 Ok(id_token)
198 }
199
200 pub fn from_config(config: &HashMap<String, String>) -> Result<Self, errors::Error> {
202 let id_token = config
203 .get(Self::CONFIG_ID_TOKEN)
204 .ok_or(errors::Error::IdTokenMissing)?
205 .clone()
206 .into();
207 let refresher = Refresher::from_config(config);
208
209 Ok(Self { id_token, refresher })
210 }
211}
212
/// The subset of JWT payload claims needed to check expiry.
#[derive(Deserialize)]
struct Claims {
    // The `exp` claim, as whole seconds since the Unix epoch (see `deserialize_expiry`).
    #[serde(rename = "exp", deserialize_with = "deserialize_expiry")]
    expiry: i64,
}
219
220fn deserialize_expiry<'de, D: Deserializer<'de>>(deserializer: D) -> core::result::Result<i64, D::Error> {
222 let json_number = Number::deserialize(deserializer)?;
223
224 json_number
225 .as_i64()
226 .or_else(|| Some(json_number.as_f64()? as i64))
227 .ok_or(serde::de::Error::custom("cannot be casted to i64"))
228}
229
/// The subset of the OIDC discovery document that the refresher needs.
#[derive(Deserialize)]
struct Metadata {
    // URL to which refresh-token grant requests are POSTed.
    token_endpoint: String,
}
235
/// How the client credentials are presented to the token endpoint.
///
/// The refresher does not know the provider's preferred style in advance. It tries
/// each entry of [`AuthStyle::ALL`] in order and remembers the first one that succeeds.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum AuthStyle {
    /// HTTP Basic `Authorization` header carrying `client_id:client_secret`.
    Header,
    /// `client_id` and `client_secret` sent as form parameters in the request body.
    Params,
}

impl AuthStyle {
    /// Every style, in the order they are attempted.
    const ALL: [AuthStyle; 2] = [AuthStyle::Header, AuthStyle::Params];
}
249
/// The subset of the token endpoint's JSON response that the refresher reads.
#[derive(Deserialize)]
struct TokenResponse {
    // New refresh token, if the provider chose to issue one; otherwise the old one is kept.
    refresh_token: Option<String>,
    // The refreshed ID token; its absence is reported as `NoIdTokenReceived`.
    id_token: Option<String>,
}
256
// The refresher needs an HTTPS client, so at least one TLS backend must be compiled in.
#[cfg(not(any(feature = "rustls-tls", feature = "openssl-tls")))]
compile_error!(
    "At least one of rustls-tls or openssl-tls feature must be enabled to use refresh-oidc feature"
);
// HTTPS connector used by the refresher. rustls wins when both TLS features are enabled.
#[cfg(feature = "rustls-tls")]
type HttpsConnector = hyper_rustls::HttpsConnector<HttpConnector>;
// OpenSSL connector, selected only when rustls is not enabled.
#[cfg(all(not(feature = "rustls-tls"), feature = "openssl-tls"))]
type HttpsConnector = hyper_openssl::HttpsConnector<HttpConnector>;
268
/// Obtains new ID tokens from the OIDC provider using the refresh-token grant.
#[derive(Debug)]
struct Refresher {
    // Issuer URL; the base from which the discovery document URL is built.
    issuer: String,
    // Token endpoint from the discovery document; `None` until the first lookup, then cached.
    token_endpoint: Option<String>,
    // Refresh token; replaced whenever the token endpoint returns a new one.
    refresh_token: SecretString,
    // OAuth2 client credentials, sent according to `auth_style`.
    client_id: SecretString,
    client_secret: SecretString,
    // HTTP client for discovery and token requests (TLS backend chosen by crate features).
    https_client: Client<HttpsConnector, String>,
    // Credential style that last succeeded; `None` until discovered on the first refresh.
    auth_style: Option<AuthStyle>,
}
286
287impl Refresher {
288 const CONFIG_CLIENT_ID: &'static str = "client-id";
290 const CONFIG_CLIENT_SECRET: &'static str = "client-secret";
292 const CONFIG_ISSUER_URL: &'static str = "idp-issuer-url";
294 const CONFIG_REFRESH_TOKEN: &'static str = "refresh-token";
296
297 fn from_config(config: &HashMap<String, String>) -> Result<Self, errors::RefreshInitError> {
299 let get_field = |name: &'static str| {
300 config
301 .get(name)
302 .cloned()
303 .ok_or(errors::RefreshInitError::MissingField(name))
304 };
305
306 let issuer = get_field(Self::CONFIG_ISSUER_URL)?;
307 let refresh_token = get_field(Self::CONFIG_REFRESH_TOKEN)?.into();
308 let client_id = get_field(Self::CONFIG_CLIENT_ID)?.into();
309 let client_secret = get_field(Self::CONFIG_CLIENT_SECRET)?.into();
310
311 #[cfg(all(feature = "rustls-tls", feature = "aws-lc-rs"))]
312 {
313 if rustls::crypto::CryptoProvider::get_default().is_none() {
314 let _ = rustls::crypto::aws_lc_rs::default_provider().install_default();
317 }
318 }
319
320 #[cfg(all(feature = "rustls-tls", not(feature = "webpki-roots")))]
321 let https = hyper_rustls::HttpsConnectorBuilder::new()
322 .with_native_roots()
323 .map_err(|_| errors::RefreshInitError::NoValidNativeRootCA)?
324 .https_only()
325 .enable_http1()
326 .build();
327 #[cfg(all(feature = "rustls-tls", feature = "webpki-roots"))]
328 let https = hyper_rustls::HttpsConnectorBuilder::new()
329 .with_webpki_roots()
330 .https_only()
331 .enable_http1()
332 .build();
333 #[cfg(all(not(feature = "rustls-tls"), feature = "openssl-tls"))]
334 let https = hyper_openssl::HttpsConnector::new()?;
335
336 let https_client = hyper_util::client::legacy::Client::builder(TokioExecutor::new()).build(https);
337
338 Ok(Self {
339 issuer,
340 token_endpoint: None,
341 refresh_token,
342 client_id,
343 client_secret,
344 https_client,
345 auth_style: None,
346 })
347 }
348
349 async fn token_endpoint(&mut self) -> Result<String, errors::RefreshError> {
352 if let Some(endpoint) = self.token_endpoint.clone() {
353 return Ok(endpoint);
354 }
355
356 let discovery = format!("{}/.well-known/openid-configuration", self.issuer).parse::<Uri>()?;
357 let response = self.https_client.get(discovery).await?;
358
359 if response.status().is_success() {
360 let body = response.into_body().collect().await?.to_bytes();
361 let metadata = serde_json::from_slice::<Metadata>(body.as_ref())
362 .map_err(errors::RefreshError::InvalidMetadata)?;
363
364 self.token_endpoint.replace(metadata.token_endpoint.clone());
365
366 Ok(metadata.token_endpoint)
367 } else {
368 Err(errors::RefreshError::RequestFailed(response.status()))
369 }
370 }
371
372 fn token_request(
374 &self,
375 endpoint: &str,
376 auth_style: AuthStyle,
377 ) -> Result<Request<String>, errors::RefreshError> {
378 let mut builder = Request::builder()
379 .uri(endpoint)
380 .method(Method::POST)
381 .header(
382 CONTENT_TYPE,
383 HeaderValue::from_static("application/x-www-form-urlencoded"),
384 )
385 .version(Version::HTTP_11);
386 let mut params = vec![
387 ("grant_type", "refresh_token"),
388 ("refresh_token", self.refresh_token.expose_secret()),
389 ];
390
391 match auth_style {
392 AuthStyle::Header => {
393 builder = builder.header(
394 AUTHORIZATION,
395 format!(
396 "Basic {}",
397 STANDARD_BASE64_ENGINE.encode(format!(
398 "{}:{}",
399 self.client_id.expose_secret(),
400 self.client_secret.expose_secret()
401 ))
402 ),
403 );
404 }
405 AuthStyle::Params => {
406 params.extend([
407 ("client_id", self.client_id.expose_secret()),
408 ("client_secret", self.client_secret.expose_secret()),
409 ]);
410 }
411 };
412
413 let body = Serializer::new(String::new()).extend_pairs(params).finish();
414
415 builder.body(body).map_err(Into::into)
416 }
417
418 async fn id_token(&mut self) -> Result<String, errors::RefreshError> {
420 let token_endpoint = self.token_endpoint().await?;
421
422 let response = match self.auth_style {
423 Some(style) => {
424 let request = self.token_request(&token_endpoint, style)?;
425 self.https_client.request(request).await?
426 }
427 None => {
428 let mut ok_response = None;
429
430 for style in AuthStyle::ALL {
431 let request = self.token_request(&token_endpoint, style)?;
432 let response = self.https_client.request(request).await?;
433 if response.status().is_success() {
434 ok_response.replace(response);
435 self.auth_style.replace(style);
436 break;
437 }
438 }
439
440 ok_response.ok_or(errors::RefreshError::AuthorizationFailure)?
441 }
442 };
443
444 if !response.status().is_success() {
445 return Err(errors::RefreshError::RequestFailed(response.status()));
446 }
447
448 let body = response.into_body().collect().await?.to_bytes();
449 let token_response = serde_json::from_slice::<TokenResponse>(body.as_ref())
450 .map_err(errors::RefreshError::InvalidTokenResponse)?;
451
452 if let Some(token) = token_response.refresh_token {
453 self.refresh_token = token.into();
454 }
455
456 token_response
457 .id_token
458 .ok_or(errors::RefreshError::NoIdTokenReceived)
459 }
460}
461
#[cfg(test)]
mod tests {
    use super::*;

    // Covers `Oidc::token_valid` for a valid token, an expired token, and three
    // malformed variants: no payload segment, bad base64, and a payload without `exp`.
    #[test]
    fn token_valid() {
        // `token_valid` only inspects the ID token, so a placeholder refresher error suffices.
        let mut oidc = Oidc {
            id_token: String::new().into(),
            refresher: Err(errors::RefreshInitError::MissingField(
                Refresher::CONFIG_REFRESH_TOKEN,
            )),
        };

        // Sample JWTs split over lines with `\` continuations. The first token's `exp`
        // lies far in the future; the second token's `exp` has already passed.
        let token_valid = "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9\
.eyJpc3MiOiJPbmxpbmUgSldUIEJ1aWxkZXIiLCJpYXQiOjE2ODc5NjU0NTIsImV4cCI6NDg0MzYzOTA5MiwiYXVkIjoid3d3LmV4YW1wbGUuY29tIiwic3ViIjoianJvY2tldEBleGFtcGxlLmNvbSIsIkVtYWlsIjoiYmVlQGV4YW1wbGUuY29tIn0\
.GKTkPMywcNQv0n01iBfv_A6VuCCCcAe72RhP0OrZsQM";
        let token_expired = "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9\
.eyJpc3MiOiJPbmxpbmUgSldUIEJ1aWxkZXIiLCJpYXQiOjE2ODc5NjU0NTIsImV4cCI6MTY4Nzk2NTU5MywiYXVkIjoid3d3LmV4YW1wbGUuY29tIiwic3ViIjoianJvY2tldEBleGFtcGxlLmNvbSIsIkVtYWlsIjoiYmVlQGV4YW1wbGUuY29tIn0\
.zTDnfI_zXIa6yPKY_ZE8r6GoLK7Syj-URcTU5_ryv1M";

        oidc.id_token = token_valid.to_string().into();
        assert!(oidc.token_valid().expect("proper token failed validation"));

        oidc.id_token = token_expired.to_string().into();
        assert!(!oidc.token_valid().expect("proper token failed validation"));

        // Header only: there is no payload segment at all.
        let malformed_token = token_expired.split_once('.').unwrap().0.to_string();
        oidc.id_token = malformed_token.into();
        oidc.token_valid().expect_err("malformed token passed validation");

        // A `?` injected at the start of the payload is not in the base64 alphabet.
        let invalid_base64_token = token_valid
            .split_once('.')
            .map(|(prefix, suffix)| format!("{}.?{}", prefix, suffix))
            .unwrap();
        oidc.id_token = invalid_base64_token.into();
        oidc.token_valid()
            .expect_err("token with invalid base64 encoding passed validation");

        // Well-formed JWT whose payload lacks the `exp` claim.
        let invalid_claims = [("sub", "jrocket@example.com"), ("aud", "www.example.com")]
            .into_iter()
            .collect::<HashMap<_, _>>();
        let invalid_claims_token = format!(
            "{}.{}.{}",
            token_valid.split_once('.').unwrap().0,
            JWT_BASE64_ENGINE.encode(serde_json::to_string(&invalid_claims).unwrap()),
            token_valid.rsplit_once('.').unwrap().1,
        );
        oidc.id_token = invalid_claims_token.into();
        oidc.token_valid()
            .expect_err("token without expiration timestamp passed validation");
    }

    // Only `id-token` present: construction succeeds but the refresher is unavailable.
    #[cfg(any(feature = "openssl-tls", feature = "rustls-tls"))]
    #[test]
    fn from_minimal_config() {
        let minimal_config = [(Oidc::CONFIG_ID_TOKEN.into(), "some_id_token".into())]
            .into_iter()
            .collect();

        let oidc = Oidc::from_config(&minimal_config)
            .expect("failed to create oidc from minimal config (only id-token)");
        assert_eq!(oidc.id_token.expose_secret(), "some_id_token");
        assert!(oidc.refresher.is_err());
    }

    // All keys present: the refresher is built, with lazily resolved fields still unset.
    #[cfg(any(feature = "openssl-tls", feature = "rustls-tls"))]
    #[test]
    fn from_full_config() {
        let full_config = [
            (Oidc::CONFIG_ID_TOKEN.into(), "some_id_token".into()),
            (Refresher::CONFIG_ISSUER_URL.into(), "some_issuer".into()),
            (
                Refresher::CONFIG_REFRESH_TOKEN.into(),
                "some_refresh_token".into(),
            ),
            (Refresher::CONFIG_CLIENT_ID.into(), "some_client_id".into()),
            (
                Refresher::CONFIG_CLIENT_SECRET.into(),
                "some_client_secret".into(),
            ),
        ]
        .into_iter()
        .collect();

        let oidc = Oidc::from_config(&full_config).expect("failed to create oidc from full config");
        assert_eq!(oidc.id_token.expose_secret(), "some_id_token");
        let refresher = oidc
            .refresher
            .as_ref()
            .expect("failed to create oidc refresher from full config");
        assert_eq!(refresher.issuer, "some_issuer");
        assert_eq!(refresher.token_endpoint, None);
        assert_eq!(refresher.refresh_token.expose_secret(), "some_refresh_token");
        assert_eq!(refresher.client_id.expose_secret(), "some_client_id");
        assert_eq!(refresher.client_secret.expose_secret(), "some_client_secret");
        assert_eq!(refresher.auth_style, None);
    }
}