kube_client/client/auth/oidc.rs

1use std::collections::HashMap;
2
3use super::TEN_SEC;
4use chrono::{TimeZone, Utc};
5use form_urlencoded::Serializer;
6use http::{
7    header::{HeaderValue, AUTHORIZATION, CONTENT_TYPE},
8    Method, Request, Uri, Version,
9};
10use http_body_util::BodyExt;
11use hyper_util::{
12    client::legacy::{connect::HttpConnector, Client},
13    rt::TokioExecutor,
14};
15use secrecy::{ExposeSecret, SecretString};
16use serde::{Deserialize, Deserializer};
17use serde_json::Number;
18
19/// Possible errors when handling OIDC authentication.
20pub mod errors {
21    use super::Oidc;
22    use http::{uri::InvalidUri, StatusCode};
23    use thiserror::Error;
24
25    /// Possible errors when extracting expiration time from an ID token.
26    #[derive(Error, Debug)]
27    pub enum IdTokenError {
28        /// Failed to extract payload from the ID token.
29        #[error("not a valid JWT token")]
30        InvalidFormat,
31        /// ID token payload is not properly encoded in base64.
32        #[error("failed to decode base64: {0}")]
33        InvalidBase64(
34            #[source]
35            #[from]
36            base64::DecodeError,
37        ),
38        /// ID token payload is not valid JSON object containing expiration timestamp.
39        #[error("failed to unmarshal JSON: {0}")]
40        InvalidJson(
41            #[source]
42            #[from]
43            serde_json::Error,
44        ),
45        /// Expiration timestamp extracted from the ID token payload is not valid.
46        #[error("invalid expiration timestamp")]
47        InvalidExpirationTimestamp,
48    }
49
50    /// Possible error when initializing the ID token refreshing.
51    #[derive(Error, Debug, Clone)]
52    pub enum RefreshInitError {
53        /// Missing field in the configuration.
54        #[error("missing field {0}")]
55        MissingField(&'static str),
56        /// Failed to create an HTTPS client.
57        #[cfg(feature = "openssl-tls")]
58        #[cfg_attr(docsrs, doc(cfg(feature = "openssl-tls")))]
59        #[error("failed to create OpenSSL HTTPS connector: {0}")]
60        CreateOpensslHttpsConnector(
61            #[source]
62            #[from]
63            openssl::error::ErrorStack,
64        ),
65        /// No valid native root CA certificates found
66        #[error("No valid native root CA certificates found")]
67        NoValidNativeRootCA,
68    }
69
70    /// Possible errors when using the refresh token.
71    #[derive(Error, Debug)]
72    pub enum RefreshError {
73        /// Failed to parse the provided issuer URL.
74        #[error("invalid URI: {0}")]
75        InvalidURI(
76            #[source]
77            #[from]
78            InvalidUri,
79        ),
80        /// [`hyper::Error`] occurred during refreshing.
81        #[error("hyper error: {0}")]
82        HyperError(
83            #[source]
84            #[from]
85            hyper::Error,
86        ),
87        /// [`hyper_util::client::legacy::Error`] occurred during refreshing.
88        #[error("hyper-util error: {0}")]
89        HyperUtilError(
90            #[source]
91            #[from]
92            hyper_util::client::legacy::Error,
93        ),
94        /// Failed to parse the metadata received from the provider.
95        #[error("invalid metadata received from the provider: {0}")]
96        InvalidMetadata(#[source] serde_json::Error),
97        /// Received an invalid status code from the provider.
98        #[error("request failed with status code: {0}")]
99        RequestFailed(StatusCode),
100        /// [`http::Error`] occurred during refreshing.
101        #[error("http error: {0}")]
102        HttpError(
103            #[source]
104            #[from]
105            http::Error,
106        ),
107        /// Failed to authorize with the provider.
108        #[error("failed to authorize with the provider using any of known authorization styles")]
109        AuthorizationFailure,
110        /// Failed to parse the token response from the provider.
111        #[error("invalid token response received from the provider: {0}")]
112        InvalidTokenResponse(#[source] serde_json::Error),
113        /// Token response from the provider did not contain an ID token.
114        #[error("no ID token received from the provider")]
115        NoIdTokenReceived,
116    }
117
118    /// Possible errors when dealing with OIDC.
119    #[derive(Error, Debug)]
120    pub enum Error {
121        /// Config did not contain the ID token.
122        #[error("missing field {}", Oidc::CONFIG_ID_TOKEN)]
123        IdTokenMissing,
124        /// Failed to retrieve expiration timestamp from the ID token.
125        #[error("invalid ID token: {0}")]
126        IdToken(
127            #[source]
128            #[from]
129            IdTokenError,
130        ),
131        /// Failed to initialize ID token refreshing.
132        #[error("ID token expired and refreshing is not possible: {0}")]
133        RefreshInit(
134            #[source]
135            #[from]
136            RefreshInitError,
137        ),
138        /// Failed to refresh the ID token.
139        #[error("ID token expired and refreshing failed: {0}")]
140        Refresh(
141            #[source]
142            #[from]
143            RefreshError,
144        ),
145    }
146}
147
148use base64::Engine as _;
149const JWT_BASE64_ENGINE: base64::engine::GeneralPurpose = base64::engine::GeneralPurpose::new(
150    &base64::alphabet::URL_SAFE,
151    base64::engine::GeneralPurposeConfig::new()
152        .with_decode_allow_trailing_bits(true)
153        .with_decode_padding_mode(base64::engine::DecodePaddingMode::Indifferent),
154);
155use base64::engine::general_purpose::STANDARD as STANDARD_BASE64_ENGINE;
156
/// OIDC auth-provider state: the current ID token plus an optional refresher.
///
/// The refresher is built from the auth-provider config when this struct is created. If
/// building it fails, the error is stored and only reported once the ID token actually
/// needs refreshing.
#[derive(Debug)]
pub struct Oidc {
    /// Current ID token; replaced whenever a refresh succeeds.
    id_token: SecretString,
    /// Refresher built from the config, or the reason it could not be built.
    refresher: Result<Refresher, errors::RefreshInitError>,
}
162
163impl Oidc {
164    /// Config key for the ID token.
165    const CONFIG_ID_TOKEN: &'static str = "id-token";
166
167    /// Check whether the stored ID token can still be used.
168    fn token_valid(&self) -> Result<bool, errors::IdTokenError> {
169        let part = self
170            .id_token
171            .expose_secret()
172            .split('.')
173            .nth(1)
174            .ok_or(errors::IdTokenError::InvalidFormat)?;
175        let payload = JWT_BASE64_ENGINE.decode(part)?;
176        let expiry = serde_json::from_slice::<Claims>(&payload)?.expiry;
177        let timestamp = Utc
178            .timestamp_opt(expiry, 0)
179            .earliest()
180            .ok_or(errors::IdTokenError::InvalidExpirationTimestamp)?;
181
182        let valid = Utc::now() + TEN_SEC < timestamp;
183
184        Ok(valid)
185    }
186
187    /// Retrieve the ID token. If the stored ID token is or will soon be expired, try refreshing it first.
188    pub async fn id_token(&mut self) -> Result<String, errors::Error> {
189        if self.token_valid()? {
190            return Ok(self.id_token.expose_secret().to_string());
191        }
192
193        let id_token = self.refresher.as_mut().map_err(|e| e.clone())?.id_token().await?;
194
195        self.id_token = id_token.clone().into();
196
197        Ok(id_token)
198    }
199
200    /// Create an instance of this struct from the auth provider config.
201    pub fn from_config(config: &HashMap<String, String>) -> Result<Self, errors::Error> {
202        let id_token = config
203            .get(Self::CONFIG_ID_TOKEN)
204            .ok_or(errors::Error::IdTokenMissing)?
205            .clone()
206            .into();
207        let refresher = Refresher::from_config(config);
208
209        Ok(Self { id_token, refresher })
210    }
211}
212
/// Subset of the ID token's JWT claims. Only the expiration time (`exp`) is read.
#[derive(Deserialize)]
struct Claims {
    /// Expiration time in seconds since the Unix epoch, read from `exp` via `deserialize_expiry`.
    #[serde(rename = "exp", deserialize_with = "deserialize_expiry")]
    expiry: i64,
}
219
220/// Deserialize expiration time from a JSON number.
221fn deserialize_expiry<'de, D: Deserializer<'de>>(deserializer: D) -> core::result::Result<i64, D::Error> {
222    let json_number = Number::deserialize(deserializer)?;
223
224    json_number
225        .as_i64()
226        .or_else(|| Some(json_number.as_f64()? as i64))
227        .ok_or(serde::de::Error::custom("cannot be casted to i64"))
228}
229
/// Provider discovery document (`.well-known/openid-configuration`).
/// Only `token_endpoint` is read; other fields in the document are ignored.
#[derive(Deserialize)]
struct Metadata {
    /// URL of the provider's token endpoint, the target of refresh requests.
    token_endpoint: String,
}
235
/// How client credentials are presented to the token endpoint.
///
/// Providers differ here. Some expect HTTP Basic auth in the `Authorization` header, some
/// expect `client_id`/`client_secret` form parameters, and some reject requests carrying both.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum AuthStyle {
    /// Credentials in an HTTP Basic `Authorization` header.
    Header,
    /// Credentials as form parameters in the request body.
    Params,
}

impl AuthStyle {
    /// Every known style, in the order they are probed when the provider's style is unknown.
    const ALL: [AuthStyle; 2] = [AuthStyle::Header, AuthStyle::Params];
}
249
/// Token endpoint response. Only `refresh_token` and `id_token` are read; any other fields
/// of the JSON response are ignored.
#[derive(Deserialize)]
struct TokenResponse {
    /// Replacement refresh token, if the provider issued one.
    refresh_token: Option<String>,
    /// Newly issued ID token; its absence is reported as `RefreshError::NoIdTokenReceived`.
    id_token: Option<String>,
}
256
// Refreshing requires an HTTPS client, so at least one TLS backend must be compiled in.
// NOTE(review): presumably this module is only compiled with the refresh-oidc feature,
// as the message suggests — confirm in the parent module.
#[cfg(not(any(feature = "rustls-tls", feature = "openssl-tls")))]
compile_error!(
    "At least one of rustls-tls or openssl-tls feature must be enabled to use refresh-oidc feature"
);
// Connector used by the refresher's HTTPS client.
// Current TLS feature precedence when more than one are set:
// 1. rustls-tls
// 2. openssl-tls
#[cfg(feature = "rustls-tls")]
type HttpsConnector = hyper_rustls::HttpsConnector<HttpConnector>;
#[cfg(all(not(feature = "rustls-tls"), feature = "openssl-tls"))]
type HttpsConnector = hyper_openssl::HttpsConnector<HttpConnector>;
268
/// Refreshes the ID token with the OAuth 2.0 `refresh_token` grant against the provider's
/// token endpoint.
#[derive(Debug)]
struct Refresher {
    /// Issuer URL from the `idp-issuer-url` config key; base of the discovery URL.
    issuer: String,
    /// Token endpoint exposed by the provider.
    /// Retrieved from the provider metadata with the first refresh request.
    token_endpoint: Option<String>,
    /// Refresh token used in the refresh requests.
    /// Updated when a new refresh token is returned by the provider.
    refresh_token: SecretString,
    /// OAuth client ID, sent according to `auth_style`.
    client_id: SecretString,
    /// OAuth client secret, sent according to `auth_style`.
    client_secret: SecretString,
    /// Client used for both discovery and token requests.
    https_client: Client<HttpsConnector, String>,
    /// Authorization style used by the provider.
    /// Determined with the first refresh request by trying all known styles.
    auth_style: Option<AuthStyle>,
}
286
287impl Refresher {
288    /// Config key for the client ID.
289    const CONFIG_CLIENT_ID: &'static str = "client-id";
290    /// Config key for the client secret.
291    const CONFIG_CLIENT_SECRET: &'static str = "client-secret";
292    /// Config key for the issuer url.
293    const CONFIG_ISSUER_URL: &'static str = "idp-issuer-url";
294    /// Config key for the refresh token.
295    const CONFIG_REFRESH_TOKEN: &'static str = "refresh-token";
296
297    /// Create a new instance of this struct from the provider config.
298    fn from_config(config: &HashMap<String, String>) -> Result<Self, errors::RefreshInitError> {
299        let get_field = |name: &'static str| {
300            config
301                .get(name)
302                .cloned()
303                .ok_or(errors::RefreshInitError::MissingField(name))
304        };
305
306        let issuer = get_field(Self::CONFIG_ISSUER_URL)?;
307        let refresh_token = get_field(Self::CONFIG_REFRESH_TOKEN)?.into();
308        let client_id = get_field(Self::CONFIG_CLIENT_ID)?.into();
309        let client_secret = get_field(Self::CONFIG_CLIENT_SECRET)?.into();
310
311        #[cfg(all(feature = "rustls-tls", feature = "aws-lc-rs"))]
312        {
313            if rustls::crypto::CryptoProvider::get_default().is_none() {
314                // the only error here is if it's been initialized in between: we can ignore it
315                // since our semantic is only to set the default value if it does not exist.
316                let _ = rustls::crypto::aws_lc_rs::default_provider().install_default();
317            }
318        }
319
320        #[cfg(all(feature = "rustls-tls", not(feature = "webpki-roots")))]
321        let https = hyper_rustls::HttpsConnectorBuilder::new()
322            .with_native_roots()
323            .map_err(|_| errors::RefreshInitError::NoValidNativeRootCA)?
324            .https_only()
325            .enable_http1()
326            .build();
327        #[cfg(all(feature = "rustls-tls", feature = "webpki-roots"))]
328        let https = hyper_rustls::HttpsConnectorBuilder::new()
329            .with_webpki_roots()
330            .https_only()
331            .enable_http1()
332            .build();
333        #[cfg(all(not(feature = "rustls-tls"), feature = "openssl-tls"))]
334        let https = hyper_openssl::HttpsConnector::new()?;
335
336        let https_client = hyper_util::client::legacy::Client::builder(TokioExecutor::new()).build(https);
337
338        Ok(Self {
339            issuer,
340            token_endpoint: None,
341            refresh_token,
342            client_id,
343            client_secret,
344            https_client,
345            auth_style: None,
346        })
347    }
348
349    /// If the token endpoint is not yet cached in this struct, extract it from the provider metadata and store in the cache.
350    /// Provider metadata is retrieved from a well-known path.
351    async fn token_endpoint(&mut self) -> Result<String, errors::RefreshError> {
352        if let Some(endpoint) = self.token_endpoint.clone() {
353            return Ok(endpoint);
354        }
355
356        let discovery = format!("{}/.well-known/openid-configuration", self.issuer).parse::<Uri>()?;
357        let response = self.https_client.get(discovery).await?;
358
359        if response.status().is_success() {
360            let body = response.into_body().collect().await?.to_bytes();
361            let metadata = serde_json::from_slice::<Metadata>(body.as_ref())
362                .map_err(errors::RefreshError::InvalidMetadata)?;
363
364            self.token_endpoint.replace(metadata.token_endpoint.clone());
365
366            Ok(metadata.token_endpoint)
367        } else {
368            Err(errors::RefreshError::RequestFailed(response.status()))
369        }
370    }
371
372    /// Prepare a token request to the provider.
373    fn token_request(
374        &self,
375        endpoint: &str,
376        auth_style: AuthStyle,
377    ) -> Result<Request<String>, errors::RefreshError> {
378        let mut builder = Request::builder()
379            .uri(endpoint)
380            .method(Method::POST)
381            .header(
382                CONTENT_TYPE,
383                HeaderValue::from_static("application/x-www-form-urlencoded"),
384            )
385            .version(Version::HTTP_11);
386        let mut params = vec![
387            ("grant_type", "refresh_token"),
388            ("refresh_token", self.refresh_token.expose_secret()),
389        ];
390
391        match auth_style {
392            AuthStyle::Header => {
393                builder = builder.header(
394                    AUTHORIZATION,
395                    format!(
396                        "Basic {}",
397                        STANDARD_BASE64_ENGINE.encode(format!(
398                            "{}:{}",
399                            self.client_id.expose_secret(),
400                            self.client_secret.expose_secret()
401                        ))
402                    ),
403                );
404            }
405            AuthStyle::Params => {
406                params.extend([
407                    ("client_id", self.client_id.expose_secret()),
408                    ("client_secret", self.client_secret.expose_secret()),
409                ]);
410            }
411        };
412
413        let body = Serializer::new(String::new()).extend_pairs(params).finish();
414
415        builder.body(body).map_err(Into::into)
416    }
417
418    /// Fetch a new ID token from the provider.
419    async fn id_token(&mut self) -> Result<String, errors::RefreshError> {
420        let token_endpoint = self.token_endpoint().await?;
421
422        let response = match self.auth_style {
423            Some(style) => {
424                let request = self.token_request(&token_endpoint, style)?;
425                self.https_client.request(request).await?
426            }
427            None => {
428                let mut ok_response = None;
429
430                for style in AuthStyle::ALL {
431                    let request = self.token_request(&token_endpoint, style)?;
432                    let response = self.https_client.request(request).await?;
433                    if response.status().is_success() {
434                        ok_response.replace(response);
435                        self.auth_style.replace(style);
436                        break;
437                    }
438                }
439
440                ok_response.ok_or(errors::RefreshError::AuthorizationFailure)?
441            }
442        };
443
444        if !response.status().is_success() {
445            return Err(errors::RefreshError::RequestFailed(response.status()));
446        }
447
448        let body = response.into_body().collect().await?.to_bytes();
449        let token_response = serde_json::from_slice::<TokenResponse>(body.as_ref())
450            .map_err(errors::RefreshError::InvalidTokenResponse)?;
451
452        if let Some(token) = token_response.refresh_token {
453            self.refresh_token = token.into();
454        }
455
456        token_response
457            .id_token
458            .ok_or(errors::RefreshError::NoIdTokenReceived)
459    }
460}
461
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn token_valid() {
        let mut oidc = Oidc {
            id_token: String::new().into(),
            refresher: Err(errors::RefreshInitError::MissingField(
                Refresher::CONFIG_REFRESH_TOKEN,
            )),
        };

        // Proper JWT expiring at 2123-06-28T15:18:12.629Z
        let token_valid = "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9\
.eyJpc3MiOiJPbmxpbmUgSldUIEJ1aWxkZXIiLCJpYXQiOjE2ODc5NjU0NTIsImV4cCI6NDg0MzYzOTA5MiwiYXVkIjoid3d3LmV4YW1wbGUuY29tIiwic3ViIjoianJvY2tldEBleGFtcGxlLmNvbSIsIkVtYWlsIjoiYmVlQGV4YW1wbGUuY29tIn0\
.GKTkPMywcNQv0n01iBfv_A6VuCCCcAe72RhP0OrZsQM";
        // Proper JWT expired at 2023-06-28T15:19:53.421Z
        let token_expired = "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9\
.eyJpc3MiOiJPbmxpbmUgSldUIEJ1aWxkZXIiLCJpYXQiOjE2ODc5NjU0NTIsImV4cCI6MTY4Nzk2NTU5MywiYXVkIjoid3d3LmV4YW1wbGUuY29tIiwic3ViIjoianJvY2tldEBleGFtcGxlLmNvbSIsIkVtYWlsIjoiYmVlQGV4YW1wbGUuY29tIn0\
.zTDnfI_zXIa6yPKY_ZE8r6GoLK7Syj-URcTU5_ryv1M";

        oidc.id_token = token_valid.to_string().into();
        assert!(oidc.token_valid().expect("proper token failed validation"));

        oidc.id_token = token_expired.to_string().into();
        assert!(!oidc.token_valid().expect("proper token failed validation"));

        let malformed_token = token_expired.split_once('.').unwrap().0.to_string();
        oidc.id_token = malformed_token.into();
        oidc.token_valid().expect_err("malformed token passed validation");

        let invalid_base64_token = token_valid
            .split_once('.')
            .map(|(prefix, suffix)| format!("{}.?{}", prefix, suffix))
            .unwrap();
        oidc.id_token = invalid_base64_token.into();
        oidc.token_valid()
            .expect_err("token with invalid base64 encoding passed validation");

        let invalid_claims = [("sub", "jrocket@example.com"), ("aud", "www.example.com")]
            .into_iter()
            .collect::<HashMap<_, _>>();
        let invalid_claims_token = format!(
            "{}.{}.{}",
            token_valid.split_once('.').unwrap().0,
            JWT_BASE64_ENGINE.encode(serde_json::to_string(&invalid_claims).unwrap()),
            token_valid.rsplit_once('.').unwrap().1,
        );
        oidc.id_token = invalid_claims_token.into();
        oidc.token_valid()
            .expect_err("token without expiration timestamp passed validation");

        // `u64::MAX` does not fit in i64, goes through the f64 fallback, saturates to
        // `i64::MAX`, and must then be rejected by the timestamp conversion.
        let out_of_range_claims = [("exp", u64::MAX)].into_iter().collect::<HashMap<_, _>>();
        let out_of_range_token = format!(
            "{}.{}.{}",
            token_valid.split_once('.').unwrap().0,
            JWT_BASE64_ENGINE.encode(serde_json::to_string(&out_of_range_claims).unwrap()),
            token_valid.rsplit_once('.').unwrap().1,
        );
        oidc.id_token = out_of_range_token.into();
        oidc.token_valid()
            .expect_err("token with out-of-range expiration timestamp passed validation");
    }

    #[test]
    fn claims_expiry() {
        let integral: Claims = serde_json::from_str(r#"{"exp": 1687965593}"#).unwrap();
        assert_eq!(integral.expiry, 1687965593);

        // Fractional timestamps are truncated toward zero.
        let fractional: Claims = serde_json::from_str(r#"{"exp": 1687965593.9}"#).unwrap();
        assert_eq!(fractional.expiry, 1687965593);

        assert!(
            serde_json::from_str::<Claims>(r#"{"exp": "1687965593"}"#).is_err(),
            "string expiration timestamp passed deserialization"
        );
        assert!(
            serde_json::from_str::<Claims>(r#"{}"#).is_err(),
            "missing expiration timestamp passed deserialization"
        );
    }

    #[cfg(any(feature = "openssl-tls", feature = "rustls-tls"))]
    #[test]
    fn from_minimal_config() {
        let minimal_config = [(Oidc::CONFIG_ID_TOKEN.into(), "some_id_token".into())]
            .into_iter()
            .collect();

        let oidc = Oidc::from_config(&minimal_config)
            .expect("failed to create oidc from minimal config (only id-token)");
        assert_eq!(oidc.id_token.expose_secret(), "some_id_token");
        assert!(oidc.refresher.is_err());
    }

    #[cfg(any(feature = "openssl-tls", feature = "rustls-tls"))]
    #[test]
    fn from_full_config() {
        let full_config = [
            (Oidc::CONFIG_ID_TOKEN.into(), "some_id_token".into()),
            (Refresher::CONFIG_ISSUER_URL.into(), "some_issuer".into()),
            (
                Refresher::CONFIG_REFRESH_TOKEN.into(),
                "some_refresh_token".into(),
            ),
            (Refresher::CONFIG_CLIENT_ID.into(), "some_client_id".into()),
            (
                Refresher::CONFIG_CLIENT_SECRET.into(),
                "some_client_secret".into(),
            ),
        ]
        .into_iter()
        .collect();

        let oidc = Oidc::from_config(&full_config).expect("failed to create oidc from full config");
        assert_eq!(oidc.id_token.expose_secret(), "some_id_token");
        let refresher = oidc
            .refresher
            .as_ref()
            .expect("failed to create oidc refresher from full config");
        assert_eq!(refresher.issuer, "some_issuer");
        assert_eq!(refresher.token_endpoint, None);
        assert_eq!(refresher.refresh_token.expose_secret(), "some_refresh_token");
        assert_eq!(refresher.client_id.expose_secret(), "some_client_id");
        assert_eq!(refresher.client_secret.expose_secret(), "some_client_secret");
        assert_eq!(refresher.auth_style, None);
    }
}