1use crate::aws_lc::{HKDF_expand, HKDF};
41use crate::error::Unspecified;
42use crate::fips::indicator_check;
43use crate::{digest, hmac};
44use alloc::sync::Arc;
45use core::fmt;
46use zeroize::Zeroize;
47
48#[derive(Clone, Copy, Debug, Eq, PartialEq)]
50pub struct Algorithm(hmac::Algorithm);
51
52impl Algorithm {
53 #[inline]
55 #[must_use]
56 pub fn hmac_algorithm(&self) -> hmac::Algorithm {
57 self.0
58 }
59}
60
/// HKDF using HMAC-SHA1. As the name says, intended for legacy use only.
pub static HKDF_SHA1_FOR_LEGACY_USE_ONLY: Algorithm =
    Algorithm(hmac::HMAC_SHA1_FOR_LEGACY_USE_ONLY);

/// HKDF using HMAC-SHA256.
pub static HKDF_SHA256: Algorithm = Algorithm(hmac::HMAC_SHA256);

/// HKDF using HMAC-SHA384.
pub static HKDF_SHA384: Algorithm = Algorithm(hmac::HMAC_SHA384);

/// HKDF using HMAC-SHA512.
pub static HKDF_SHA512: Algorithm = Algorithm(hmac::HMAC_SHA512);
73
/// Maximum salt length, in bytes, that a `Salt` can hold inline.
///
/// `Salt::try_new` rejects longer values, and `Salt::new` panics on them.
const MAX_HKDF_SALT_LEN: usize = 80;

/// Default capacity, in bytes, for a buffer holding concatenated HKDF `info` segments.
const HKDF_INFO_DEFAULT_CAPACITY_LEN: usize = 300;

/// Maximum length, in bytes, of a PRK stored inline; sized to `digest::MAX_OUTPUT_LEN`.
///
/// `Prk::try_new_less_safe` rejects longer values.
const MAX_HKDF_PRK_LEN: usize = digest::MAX_OUTPUT_LEN;
86
87impl KeyType for Algorithm {
88 fn len(&self) -> usize {
89 self.0.digest_algorithm().output_len
90 }
91}
92
93pub struct Salt {
95 algorithm: Algorithm,
96 salt_bytes: [u8; MAX_HKDF_SALT_LEN],
97 salt_len: usize,
98}
99
100#[allow(clippy::missing_fields_in_debug)]
101impl fmt::Debug for Salt {
102 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
103 f.debug_struct("hkdf::Salt")
104 .field("algorithm", &self.algorithm.0)
105 .finish()
106 }
107}
108
109impl Drop for Salt {
110 fn drop(&mut self) {
111 self.salt_bytes.zeroize();
112 }
113}
114
115impl Salt {
116 #[must_use]
134 pub fn new(algorithm: Algorithm, value: &[u8]) -> Self {
135 Salt::try_new(algorithm, value).expect("Salt length limit exceeded.")
136 }
137
138 fn try_new(algorithm: Algorithm, value: &[u8]) -> Result<Salt, Unspecified> {
139 let salt_len = value.len();
140 if salt_len > MAX_HKDF_SALT_LEN {
141 return Err(Unspecified);
142 }
143 let mut salt_bytes = [0u8; MAX_HKDF_SALT_LEN];
144 salt_bytes[0..salt_len].copy_from_slice(value);
145 Ok(Self {
146 algorithm,
147 salt_bytes,
148 salt_len,
149 })
150 }
151
152 #[inline]
159 #[must_use]
160 pub fn extract(&self, secret: &[u8]) -> Prk {
161 Prk {
162 algorithm: self.algorithm,
163 mode: PrkMode::ExtractExpand {
164 secret: Arc::from(ZeroizeBoxSlice::from(secret)),
165 salt: self.salt_bytes,
166 salt_len: self.salt_len,
167 },
168 }
169 }
170
171 #[inline]
173 #[must_use]
174 pub fn algorithm(&self) -> Algorithm {
175 Algorithm(self.algorithm.hmac_algorithm())
176 }
177}
178
// Compile-time check that a PRK-sized buffer fits in `Salt`'s inline salt array.
// `From<Okm<'_, Algorithm>> for Salt` writes a digest-output-length prefix of
// `salt_bytes`. That slice stays in bounds as long as digest output lengths do not
// exceed `digest::MAX_OUTPUT_LEN` (== `MAX_HKDF_PRK_LEN`).
#[allow(clippy::assertions_on_constants)]
const _: () = assert!(MAX_HKDF_PRK_LEN <= MAX_HKDF_SALT_LEN);
181
182impl From<Okm<'_, Algorithm>> for Salt {
183 fn from(okm: Okm<'_, Algorithm>) -> Self {
184 let algorithm = okm.prk.algorithm;
185 let mut salt_bytes = [0u8; MAX_HKDF_SALT_LEN];
186 let salt_len = okm.len().len();
187 okm.fill(&mut salt_bytes[..salt_len]).unwrap();
188 Self {
189 algorithm,
190 salt_bytes,
191 salt_len,
192 }
193 }
194}
195
/// A type that describes how many bytes of output keying material
/// [`Prk::expand`] should produce.
// clippy's `len_without_is_empty` is allowed: `len` reports a requested output
// length, not the size of a collection.
#[allow(clippy::len_without_is_empty)]
pub trait KeyType {
    /// The requested output length, in bytes.
    fn len(&self) -> usize;
}
202
203#[derive(Clone)]
204enum PrkMode {
205 Expand {
206 key_bytes: [u8; MAX_HKDF_PRK_LEN],
207 key_len: usize,
208 },
209 ExtractExpand {
210 secret: Arc<ZeroizeBoxSlice<u8>>,
211 salt: [u8; MAX_HKDF_SALT_LEN],
212 salt_len: usize,
213 },
214}
215
216impl PrkMode {
217 fn fill(&self, algorithm: Algorithm, out: &mut [u8], info: &[u8]) -> Result<(), Unspecified> {
218 let digest = *digest::match_digest_type(&algorithm.0.digest_algorithm().id);
219
220 match &self {
221 PrkMode::Expand { key_bytes, key_len } => unsafe {
222 if 1 != indicator_check!(HKDF_expand(
223 out.as_mut_ptr(),
224 out.len(),
225 digest,
226 key_bytes.as_ptr(),
227 *key_len,
228 info.as_ptr(),
229 info.len(),
230 )) {
231 return Err(Unspecified);
232 }
233 },
234 PrkMode::ExtractExpand {
235 secret,
236 salt,
237 salt_len,
238 } => {
239 if 1 != indicator_check!(unsafe {
240 HKDF(
241 out.as_mut_ptr(),
242 out.len(),
243 digest,
244 secret.as_ptr(),
245 secret.len(),
246 salt.as_ptr(),
247 *salt_len,
248 info.as_ptr(),
249 info.len(),
250 )
251 }) {
252 return Err(Unspecified);
253 }
254 }
255 }
256
257 Ok(())
258 }
259}
260
261impl fmt::Debug for PrkMode {
262 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
263 match self {
264 Self::Expand { .. } => f.debug_struct("Expand").finish_non_exhaustive(),
265 Self::ExtractExpand { .. } => f.debug_struct("ExtractExpand").finish_non_exhaustive(),
266 }
267 }
268}
269
270struct ZeroizeBoxSlice<T: Zeroize>(Box<[T]>);
271
272impl<T: Zeroize> core::ops::Deref for ZeroizeBoxSlice<T> {
273 type Target = [T];
274
275 fn deref(&self) -> &Self::Target {
276 &self.0
277 }
278}
279
280impl<T: Clone + Zeroize> From<&[T]> for ZeroizeBoxSlice<T> {
281 fn from(value: &[T]) -> Self {
282 Self(Vec::from(value).into_boxed_slice())
283 }
284}
285
286impl<T: Zeroize> Drop for ZeroizeBoxSlice<T> {
287 fn drop(&mut self) {
288 self.0.zeroize();
289 }
290}
291
292#[derive(Clone)]
294pub struct Prk {
295 algorithm: Algorithm,
296 mode: PrkMode,
297}
298
299#[allow(clippy::missing_fields_in_debug)]
300impl fmt::Debug for Prk {
301 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
302 f.debug_struct("hkdf::Prk")
303 .field("algorithm", &self.algorithm.0)
304 .field("mode", &self.mode)
305 .finish()
306 }
307}
308
309impl Prk {
310 #[must_use]
324 pub fn new_less_safe(algorithm: Algorithm, value: &[u8]) -> Self {
325 Prk::try_new_less_safe(algorithm, value).expect("Prk length limit exceeded.")
326 }
327
328 fn try_new_less_safe(algorithm: Algorithm, value: &[u8]) -> Result<Prk, Unspecified> {
329 let key_len = value.len();
330 if key_len > MAX_HKDF_PRK_LEN {
331 return Err(Unspecified);
332 }
333 let mut key_bytes = [0u8; MAX_HKDF_PRK_LEN];
334 key_bytes[0..key_len].copy_from_slice(value);
335 Ok(Self {
336 algorithm,
337 mode: PrkMode::Expand { key_bytes, key_len },
338 })
339 }
340
341 #[inline]
354 pub fn expand<'a, L: KeyType>(
355 &'a self,
356 info: &'a [&'a [u8]],
357 len: L,
358 ) -> Result<Okm<'a, L>, Unspecified> {
359 let len_cached = len.len();
360 if len_cached > 255 * self.algorithm.0.digest_algorithm().output_len {
361 return Err(Unspecified);
362 }
363 let mut info_bytes: Vec<u8> = Vec::with_capacity(HKDF_INFO_DEFAULT_CAPACITY_LEN);
364 let mut info_len = 0;
365 for &byte_ary in info {
366 info_bytes.extend_from_slice(byte_ary);
367 info_len += byte_ary.len();
368 }
369 let info_bytes = info_bytes.into_boxed_slice();
370 Ok(Okm {
371 prk: self,
372 info_bytes,
373 info_len,
374 len,
375 })
376 }
377}
378
379impl From<Okm<'_, Algorithm>> for Prk {
380 fn from(okm: Okm<Algorithm>) -> Self {
381 let algorithm = okm.len;
382 let key_len = okm.len.len();
383 let mut key_bytes = [0u8; MAX_HKDF_PRK_LEN];
384 okm.fill(&mut key_bytes[0..key_len]).unwrap();
385
386 Self {
387 algorithm,
388 mode: PrkMode::Expand { key_bytes, key_len },
389 }
390 }
391}
392
393pub struct Okm<'a, L: KeyType> {
398 prk: &'a Prk,
399 info_bytes: Box<[u8]>,
400 info_len: usize,
401 len: L,
402}
403
404impl<L: KeyType> fmt::Debug for Okm<'_, L> {
405 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
406 f.debug_struct("hkdf::Okm").field("prk", &self.prk).finish()
407 }
408}
409
410impl<L: KeyType> Drop for Okm<'_, L> {
411 fn drop(&mut self) {
412 self.info_bytes.zeroize();
413 }
414}
415
416impl<L: KeyType> Okm<'_, L> {
417 #[inline]
419 pub fn len(&self) -> &L {
420 &self.len
421 }
422
423 #[inline]
441 pub fn fill(self, out: &mut [u8]) -> Result<(), Unspecified> {
442 if out.len() != self.len.len() {
443 return Err(Unspecified);
444 }
445
446 self.prk
447 .mode
448 .fill(self.prk.algorithm, out, &self.info_bytes[..self.info_len])?;
449
450 Ok(())
451 }
452}
453
#[cfg(test)]
mod tests {
    use crate::hkdf::{Salt, HKDF_SHA256, HKDF_SHA384};

    #[cfg(feature = "fips")]
    mod fips;

    #[test]
    fn hkdf_coverage() {
        // Distinct algorithms compare unequal; Debug shows the wrapped HMAC algorithm.
        assert_ne!(HKDF_SHA256, HKDF_SHA384);
        let rendered = format!("{HKDF_SHA256:?}");
        assert_eq!("Algorithm(Algorithm(SHA256))", rendered);
    }

    #[test]
    fn test_debug() {
        const SALT: &[u8; 32] = &[
            29, 113, 120, 243, 11, 202, 39, 222, 206, 81, 163, 184, 122, 153, 52, 192, 98, 195,
            240, 32, 34, 19, 160, 128, 178, 111, 97, 232, 113, 101, 221, 143,
        ];
        const SECRET1: &[u8; 32] = &[
            157, 191, 36, 107, 110, 131, 193, 6, 175, 226, 193, 3, 168, 133, 165, 181, 65, 120,
            194, 152, 31, 92, 37, 191, 73, 222, 41, 112, 207, 236, 196, 174,
        ];

        const INFO1: &[&[u8]] = &[
            &[
                2, 130, 61, 83, 192, 248, 63, 60, 211, 73, 169, 66, 101, 160, 196, 212, 250, 113,
            ],
            &[
                80, 46, 248, 123, 78, 204, 171, 178, 67, 204, 96, 27, 131, 24,
            ],
        ];

        let algorithm = HKDF_SHA256;
        let salt = Salt::new(algorithm, SALT);
        let prk = salt.extract(SECRET1);
        let okm = prk.expand(INFO1, algorithm).unwrap();

        // None of the Debug renderings may expose salt, secret or key bytes.
        let salt_dbg = format!("{salt:?}");
        let prk_dbg = format!("{prk:?}");
        let okm_dbg = format!("{okm:?}");

        assert_eq!("hkdf::Salt { algorithm: Algorithm(SHA256) }", salt_dbg);
        assert_eq!(
            "hkdf::Prk { algorithm: Algorithm(SHA256), mode: ExtractExpand { .. } }",
            prk_dbg
        );
        assert_eq!(
            "hkdf::Okm { prk: hkdf::Prk { algorithm: Algorithm(SHA256), mode: ExtractExpand { .. } } }",
            okm_dbg
        );
    }
}