1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
//! Counter with CBC-MAC ([CCM]): [Authenticated Encryption with Associated Data (AEAD)][1]
//! algorithm generic over block ciphers with block size equal to 128 bits, as specified in
//! [RFC 3610].
//!
//! # Usage
//!
//! Simple usage (allocating, no associated data):
//!
#![cfg_attr(all(feature = "getrandom", feature = "std"), doc = "```")]
#![cfg_attr(not(all(feature = "getrandom", feature = "std")), doc = "```ignore")]
//! # fn main() -> Result<(), Box<dyn std::error::Error>> {
//! use aes::Aes256;
//! use ccm::{
//!     aead::{Aead, KeyInit, OsRng, generic_array::GenericArray},
//!     consts::{U10, U13},
//!     Ccm,
//! };
//!
//! // AES-256-CCM type with tag and nonce size equal to 10 and 13 bytes respectively
//! pub type Aes256Ccm = Ccm<Aes256, U10, U13>;
//!
//! let key = Aes256Ccm::generate_key(&mut OsRng);
//! let cipher = Aes256Ccm::new(&key);
//! let nonce = GenericArray::from_slice(b"unique nonce."); // 13-bytes; unique per message
//! let ciphertext = cipher.encrypt(nonce, b"plaintext message".as_ref())?;
//! let plaintext = cipher.decrypt(nonce, ciphertext.as_ref())?;
//! assert_eq!(&plaintext, b"plaintext message");
//! # Ok(())
//! # }
//! ```
//!
//! This crate implements traits from the [`aead`] crate and is capable of performing
//! encryption and decryption in-place without relying on `alloc`.
//!
//! [RFC 3610]: https://tools.ietf.org/html/rfc3610
//! [CCM]: https://en.wikipedia.org/wiki/CCM_mode
//! [aead]: https://docs.rs/aead
//! [1]: https://en.wikipedia.org/wiki/Authenticated_encryption

#![no_std]
#![doc(
    html_logo_url = "https://raw.githubusercontent.com/RustCrypto/meta/master/logo.svg",
    html_favicon_url = "https://raw.githubusercontent.com/RustCrypto/meta/master/logo.svg"
)]
#![deny(unsafe_code)]
#![warn(missing_docs, rust_2018_idioms)]

pub use aead::{self, consts, AeadCore, AeadInPlace, Error, Key, KeyInit, KeySizeUser};

use aead::{
    consts::{U0, U16},
    generic_array::{typenum::Unsigned, ArrayLength, GenericArray},
};
use cipher::{
    Block, BlockCipher, BlockEncrypt, BlockSizeUser, InnerIvInit, StreamCipher, StreamCipherSeek,
};
use core::marker::PhantomData;
use ctr::{Ctr32BE, Ctr64BE, CtrCore};
use subtle::ConstantTimeEq;

mod private;

/// CCM nonce: a byte array whose length is given by the `NonceSize` type parameter.
pub type Nonce<NonceSize> = GenericArray<u8, NonceSize>;

/// CCM authentication tag: a byte array whose length is given by the `TagSize` type parameter.
pub type Tag<TagSize> = GenericArray<u8, TagSize>;

/// Trait implemented for valid tag sizes, i.e.
/// [`U4`][consts::U4], [`U6`][consts::U6], [`U8`][consts::U8],
/// [`U10`][consts::U10], [`U12`][consts::U12], [`U14`][consts::U14], and
/// [`U16`][consts::U16].
///
/// This is a marker trait without methods of its own. Its supertrait lives in
/// the non-public `private` module, so (presumably by design) it cannot be
/// implemented for additional types outside of this crate.
pub trait TagSize: private::SealedTag {}

// Blanket impl: every type that implements `private::SealedTag` is a `TagSize`.
impl<T: private::SealedTag> TagSize for T {}

/// Trait implemented for valid nonce sizes, i.e.
/// [`U7`][consts::U7], [`U8`][consts::U8], [`U9`][consts::U9],
/// [`U10`][consts::U10], [`U11`][consts::U11], [`U12`][consts::U12], and
/// [`U13`][consts::U13].
///
/// This is a marker trait without methods of its own. Its supertrait lives in
/// the non-public `private` module, so (presumably by design) it cannot be
/// implemented for additional types outside of this crate.
pub trait NonceSize: private::SealedNonce {}

// Blanket impl: every type that implements `private::SealedNonce` is a `NonceSize`.
impl<T: private::SealedNonce> NonceSize for T {}

/// CCM instance generic over an underlying block cipher.
///
/// Type parameters:
/// - `C`: block cipher with a 16-byte (128-bit) block size.
/// - `M`: size of MAC tag in bytes, valid values:
///   [`U4`][consts::U4], [`U6`][consts::U6], [`U8`][consts::U8],
///   [`U10`][consts::U10], [`U12`][consts::U12], [`U14`][consts::U14],
///   [`U16`][consts::U16].
/// - `N`: size of nonce in bytes, valid values:
///   [`U7`][consts::U7], [`U8`][consts::U8], [`U9`][consts::U9],
///   [`U10`][consts::U10], [`U11`][consts::U11], [`U12`][consts::U12],
///   [`U13`][consts::U13].
#[derive(Clone)]
pub struct Ccm<C, M, N>
where
    C: BlockCipher + BlockSizeUser<BlockSize = U16> + BlockEncrypt,
    M: ArrayLength<u8> + TagSize,
    N: ArrayLength<u8> + NonceSize,
{
    /// Underlying block cipher instance.
    cipher: C,
    /// Zero-sized marker carrying the tag size `M` and nonce size `N` in the type.
    _pd: PhantomData<(M, N)>,
}

impl<C, M, N> Ccm<C, M, N>
where
    C: BlockCipher + BlockSizeUser<BlockSize = U16> + BlockEncrypt,
    M: ArrayLength<u8> + TagSize,
    N: ArrayLength<u8> + NonceSize,
{
    /// Builds the initial counter block for CTR mode: a leading flags byte
    /// holding `L - 1`, followed by the nonce. The remaining (counter) bytes
    /// are left at zero.
    fn extend_nonce(nonce: &Nonce<N>) -> Block<C> {
        let mut counter_block = Block::<C>::default();
        counter_block[0] = N::get_l() - 1;
        counter_block[1..=nonce.len()].copy_from_slice(nonce);
        counter_block
    }

    /// Computes the raw (not yet encrypted) CBC-MAC over the header block
    /// `B_0`, the length-prefixed associated data (if any) and `buffer`.
    ///
    /// Partial trailing blocks of the associated data and of `buffer` are
    /// zero-padded independently.
    ///
    /// # Errors
    ///
    /// Returns [`Error`] if `buffer` is longer than `N::get_max_len()`.
    fn calc_mac(
        &self,
        nonce: &Nonce<N>,
        adata: &[u8],
        buffer: &[u8],
    ) -> Result<Tag<C::BlockSize>, Error> {
        let has_adata = !adata.is_empty();
        let l_minus_one = N::get_l() - 1;
        let flags = 64 * u8::from(has_adata) + 8 * M::get_m_tick() + l_minus_one;

        if buffer.len() > N::get_max_len() {
            return Err(Error);
        }

        // B_0 = flags || nonce || big-endian message length
        let mut b0 = Block::<C>::default();
        let len_offset = 1 + N::to_usize();
        b0[0] = flags;
        b0[1..len_offset].copy_from_slice(nonce);

        // Only the low-order bytes of the length fit into the length field;
        // the max len check above makes certain that every discarded
        // high-order byte is zero.
        let len_bytes = (buffer.len() as u64).to_be_bytes();
        let len_field = &mut b0[len_offset..];
        let skipped = len_bytes.len() - len_field.len();
        len_field.copy_from_slice(&len_bytes[skipped..]);

        let mut mac = CbcMac::from_cipher(&self.cipher);
        mac.block_update(&b0);

        if has_adata {
            // The first associated-data block starts with the encoded length
            // and is filled up with as much of `adata` as fits; whatever is
            // left over is fed through the MAC afterwards.
            let (header_len, mut first_block) = fill_aad_header(adata.len());
            let room = first_block.len() - header_len;
            let (head, tail) = adata.split_at(adata.len().min(room));
            first_block[header_len..][..head.len()].copy_from_slice(head);
            mac.block_update(&first_block);
            // No-op when all of `adata` fitted into the first block.
            mac.update(tail);
        }

        mac.update(buffer);

        Ok(mac.finalize())
    }
}

// Allows constructing a CCM instance from an existing block cipher instance.
impl<C, M, N> From<C> for Ccm<C, M, N>
where
    C: BlockCipher + BlockSizeUser<BlockSize = U16> + BlockEncrypt,
    M: ArrayLength<u8> + TagSize,
    N: ArrayLength<u8> + NonceSize,
{
    /// Wraps the given block cipher instance; the tag and nonce sizes are
    /// carried purely at the type level.
    fn from(cipher: C) -> Self {
        Self {
            cipher,
            _pd: PhantomData,
        }
    }
}

// The CCM key is passed straight to the underlying block cipher, so the key
// size is the cipher's key size.
impl<C, M, N> KeySizeUser for Ccm<C, M, N>
where
    C: BlockCipher + BlockSizeUser<BlockSize = U16> + BlockEncrypt + KeyInit,
    M: ArrayLength<u8> + TagSize,
    N: ArrayLength<u8> + NonceSize,
{
    type KeySize = C::KeySize;
}

// Key-based construction, available when the block cipher itself can be
// initialized from a key.
impl<C, M, N> KeyInit for Ccm<C, M, N>
where
    C: BlockCipher + BlockSizeUser<BlockSize = U16> + BlockEncrypt + KeyInit,
    M: ArrayLength<u8> + TagSize,
    N: ArrayLength<u8> + NonceSize,
{
    /// Initializes the underlying block cipher with `key` and wraps it in a
    /// CCM instance.
    fn new(key: &Key<Self>) -> Self {
        Self::from(C::new(key))
    }
}

// AEAD parameters are taken directly from the type parameters.
impl<C, M, N> AeadCore for Ccm<C, M, N>
where
    C: BlockCipher + BlockSizeUser<BlockSize = U16> + BlockEncrypt,
    M: ArrayLength<u8> + TagSize,
    N: ArrayLength<u8> + NonceSize,
{
    // Nonce length in bytes.
    type NonceSize = N;
    // Tag length in bytes.
    type TagSize = M;
    // No ciphertext overhead is declared (the tag is accounted for separately
    // via `TagSize`).
    type CiphertextOverhead = U0;
}

impl<C, M, N> AeadInPlace for Ccm<C, M, N>
where
    C: BlockCipher + BlockSizeUser<BlockSize = U16> + BlockEncrypt,
    M: ArrayLength<u8> + TagSize,
    N: ArrayLength<u8> + NonceSize,
{
    fn encrypt_in_place_detached(
        &self,
        nonce: &Nonce<N>,
        adata: &[u8],
        buffer: &mut [u8],
    ) -> Result<Tag<Self::TagSize>, Error> {
        let mut full_tag = self.calc_mac(nonce, adata, buffer)?;

        let ext_nonce = Self::extend_nonce(nonce);
        // number of bytes left for counter (max 8)
        let cb = C::BlockSize::USIZE - N::USIZE - 1;

        if cb > 4 {
            let mut ctr = Ctr64BE::from_core(CtrCore::inner_iv_init(&self.cipher, &ext_nonce));
            ctr.apply_keystream(&mut full_tag);
            ctr.apply_keystream(buffer);
        } else {
            let mut ctr = Ctr32BE::from_core(CtrCore::inner_iv_init(&self.cipher, &ext_nonce));
            ctr.apply_keystream(&mut full_tag);
            ctr.apply_keystream(buffer);
        }

        Ok(Tag::clone_from_slice(&full_tag[..M::to_usize()]))
    }

    fn decrypt_in_place_detached(
        &self,
        nonce: &Nonce<N>,
        adata: &[u8],
        buffer: &mut [u8],
        tag: &Tag<Self::TagSize>,
    ) -> Result<(), Error> {
        let ext_nonce = Self::extend_nonce(nonce);
        // number of bytes left for counter (max 8)
        let cb = C::BlockSize::USIZE - N::USIZE - 1;

        if cb > 4 {
            let mut ctr = Ctr64BE::from_core(CtrCore::inner_iv_init(&self.cipher, &ext_nonce));
            ctr.seek(C::BlockSize::USIZE);
            ctr.apply_keystream(buffer);
        } else {
            let mut ctr = Ctr32BE::from_core(CtrCore::inner_iv_init(&self.cipher, &ext_nonce));
            ctr.seek(C::BlockSize::USIZE);
            ctr.apply_keystream(buffer);
        }

        let mut full_tag = self.calc_mac(nonce, adata, buffer)?;

        if cb > 4 {
            let mut ctr = Ctr64BE::from_core(CtrCore::inner_iv_init(&self.cipher, &ext_nonce));
            ctr.apply_keystream(&mut full_tag);
        } else {
            let mut ctr = Ctr32BE::from_core(CtrCore::inner_iv_init(&self.cipher, &ext_nonce));
            ctr.apply_keystream(&mut full_tag);
        }

        if full_tag[..tag.len()].ct_eq(tag).into() {
            Ok(())
        } else {
            buffer.iter_mut().for_each(|v| *v = 0);
            Err(Error)
        }
    }
}

/// Minimal CBC-MAC state over a borrowed block cipher.
struct CbcMac<'a, C: BlockCipher + BlockEncrypt> {
    /// Block cipher used to encrypt the chaining state.
    cipher: &'a C,
    /// Current chaining value; starts out as an all-default (zero) block.
    state: Block<C>,
}

impl<'a, C> CbcMac<'a, C>
where
    C: BlockCipher + BlockEncrypt,
{
    /// Creates a fresh MAC state (default-initialized chaining value) that
    /// borrows `cipher`.
    fn from_cipher(cipher: &'a C) -> Self {
        let state = Block::<C>::default();
        Self { cipher, state }
    }

    /// Absorbs `data` block by block. A trailing partial block is padded
    /// with zeros; empty input leaves the state untouched.
    fn update(&mut self, data: &[u8]) {
        let mut full_blocks = data.chunks_exact(C::BlockSize::USIZE);
        full_blocks
            .by_ref()
            .for_each(|chunk| self.block_update(Block::<C>::from_slice(chunk)));

        let tail = full_blocks.remainder();
        if tail.is_empty() {
            return;
        }

        let mut padded = Block::<C>::default();
        padded[..tail.len()].copy_from_slice(tail);
        self.block_update(&padded);
    }

    /// Absorbs exactly one block: XORs it into the state and encrypts the
    /// state in place.
    fn block_update(&mut self, block: &Block<C>) {
        for (acc, byte) in self.state.iter_mut().zip(block.iter()) {
            *acc ^= *byte;
        }
        self.cipher.encrypt_block(&mut self.state);
    }

    /// Consumes the MAC and returns the final chaining value.
    fn finalize(self) -> Block<C> {
        let Self { state, .. } = self;
        state
    }
}

/// Writes the encoded associated-data length to the start of a zeroed
/// 16-byte block.
///
/// Encoding (big-endian):
/// - `adata_len < 0xFF00`: the length as 2 bytes;
/// - `adata_len <= u32::MAX`: the marker `0xFF 0xFE` followed by the length as 4 bytes;
/// - otherwise: the marker `0xFF 0xFF` followed by the length as 8 bytes.
///
/// Returns the number of header bytes written together with the block; the
/// bytes after the header are zero and available for associated data.
/// `adata_len` must be non-zero (checked in debug builds only).
fn fill_aad_header(adata_len: usize) -> (usize, GenericArray<u8, U16>) {
    debug_assert_ne!(adata_len, 0);

    let mut header = GenericArray::<u8, U16>::default();

    // Short form: the length itself is the whole header.
    if adata_len < 0xFF00 {
        header[..2].copy_from_slice(&(adata_len as u16).to_be_bytes());
        return (2, header);
    }

    // Both long forms start with 0xFF; the second byte selects the width.
    header[0] = 0xFF;
    if adata_len <= u32::MAX as usize {
        header[1] = 0xFE;
        header[2..6].copy_from_slice(&(adata_len as u32).to_be_bytes());
        (6, header)
    } else {
        header[1] = 0xFF;
        header[2..10].copy_from_slice(&(adata_len as u64).to_be_bytes());
        (10, header)
    }
}

#[cfg(test)]
mod tests {
    use super::fill_aad_header;
    use hex_literal::hex;

    /// Checks both the reported header length and the full block contents
    /// produced for `adata_len`.
    fn check_header(adata_len: usize, expected_len: usize, expected_block: &[u8]) {
        let (n, b) = fill_aad_header(adata_len);
        assert_eq!(n, expected_len);
        assert_eq!(&b[..], expected_block);
    }

    #[test]
    fn fill_aad_header_test() {
        // 2-byte form
        check_header(0x0123, 2, &hex!("01230000000000000000000000000000"));

        // smallest length that needs the 0xFFFE (4-byte) form
        check_header(0xFF00, 6, &hex!("FFFE0000FF0000000000000000000000"));

        // 4-byte form
        check_header(0x01234567, 6, &hex!("FFFE0123456700000000000000000000"));

        // 8-byte form; the literal only fits into `usize` on 64-bit targets
        #[cfg(target_pointer_width = "64")]
        {
            check_header(
                0x0123456789ABCDEF,
                10,
                &hex!("FFFF0123456789ABCDEF000000000000"),
            );
        }
    }
}