1use std::{io::prelude::*, path::PathBuf};
12
13use base64::prelude::{Engine, BASE64_STANDARD};
14pub use crypto_secretbox::Key;
15use crypto_secretbox::{
16 aead::{Nonce, OsRng},
17 AeadCore, AeadInPlace, KeyInit, XSalsa20Poly1305,
18};
19use eyre::{bail, ensure, eyre, Context, Result};
20use fs_err as fs;
21use rmp::{decode::Bytes, Marker};
22use serde::{Deserialize, Serialize};
23use time::{format_description::well_known::Rfc3339, macros::format_description, OffsetDateTime};
24
25use crate::{history::History, settings::Settings};
26
27#[derive(Debug, Serialize, Deserialize)]
28pub struct EncryptedHistory {
29 pub ciphertext: Vec<u8>,
30 pub nonce: Nonce<XSalsa20Poly1305>,
31}
32
33pub fn generate_encoded_key() -> Result<(Key, String)> {
34 let key = XSalsa20Poly1305::generate_key(&mut OsRng);
35 let encoded = encode_key(&key)?;
36
37 Ok((key, encoded))
38}
39
40pub fn new_key(settings: &Settings) -> Result<Key> {
41 let path = settings.key_path.as_str();
42 let path = PathBuf::from(path);
43
44 if path.exists() {
45 bail!("key already exists! cannot overwrite");
46 }
47
48 let (key, encoded) = generate_encoded_key()?;
49
50 let mut file = fs::File::create(path)?;
51 file.write_all(encoded.as_bytes())?;
52
53 Ok(key)
54}
55
56pub fn load_key(settings: &Settings) -> Result<Key> {
58 let path = settings.key_path.as_str();
59
60 let key = if PathBuf::from(path).exists() {
61 let key = fs_err::read_to_string(path)?;
62 decode_key(key)?
63 } else {
64 new_key(settings)?
65 };
66
67 Ok(key)
68}
69
70pub fn encode_key(key: &Key) -> Result<String> {
71 let mut buf = vec![];
72 rmp::encode::write_array_len(&mut buf, key.len() as u32)
73 .wrap_err("could not encode key to message pack")?;
74 for b in key {
75 rmp::encode::write_uint(&mut buf, *b as u64)
76 .wrap_err("could not encode key to message pack")?;
77 }
78 let buf = BASE64_STANDARD.encode(buf);
79
80 Ok(buf)
81}
82
83pub fn decode_key(key: String) -> Result<Key> {
84 use rmp::decode;
85
86 let buf = BASE64_STANDARD
87 .decode(key.trim_end())
88 .wrap_err("encryption key is not a valid base64 encoding")?;
89
90 match <[u8; 32]>::try_from(&*buf) {
93 Ok(key) => Ok(key.into()),
94 Err(_) => {
95 let mut bytes = rmp::decode::Bytes::new(&buf);
96
97 match Marker::from_u8(buf[0]) {
98 Marker::Bin8 => {
99 let len = decode::read_bin_len(&mut bytes).map_err(|err| eyre!("{err:?}"))?;
100 ensure!(len == 32, "encryption key is not the correct size");
101 let key = <[u8; 32]>::try_from(bytes.remaining_slice())
102 .context("could not decode encryption key")?;
103 Ok(key.into())
104 }
105 Marker::Array16 => {
106 let len = decode::read_array_len(&mut bytes).map_err(|err| eyre!("{err:?}"))?;
107 ensure!(len == 32, "encryption key is not the correct size");
108
109 let mut key = Key::default();
110 for i in &mut key {
111 *i = rmp::decode::read_int(&mut bytes).map_err(|err| eyre!("{err:?}"))?;
112 }
113 Ok(key)
114 }
115 _ => bail!("could not decode encryption key"),
116 }
117 }
118 }
119}
120
121pub fn encrypt(history: &History, key: &Key) -> Result<EncryptedHistory> {
122 let mut buf = encode(history)?;
124
125 let nonce = XSalsa20Poly1305::generate_nonce(&mut OsRng);
126 XSalsa20Poly1305::new(key)
127 .encrypt_in_place(&nonce, &[], &mut buf)
128 .map_err(|_| eyre!("could not encrypt"))?;
129
130 Ok(EncryptedHistory {
131 ciphertext: buf,
132 nonce,
133 })
134}
135
136pub fn decrypt(mut encrypted_history: EncryptedHistory, key: &Key) -> Result<History> {
137 XSalsa20Poly1305::new(key)
138 .decrypt_in_place(
139 &encrypted_history.nonce,
140 &[],
141 &mut encrypted_history.ciphertext,
142 )
143 .map_err(|_| eyre!("could not decrypt history"))?;
144 let plaintext = encrypted_history.ciphertext;
145
146 let history = decode(&plaintext)?;
147
148 Ok(history)
149}
150
151fn format_rfc3339(ts: OffsetDateTime) -> Result<String> {
152 static PARTIAL_RFC3339_0: &[time::format_description::FormatItem<'static>] =
155 format_description!("[year]-[month]-[day]T[hour]:[minute]:[second]Z");
156 static PARTIAL_RFC3339_3: &[time::format_description::FormatItem<'static>] =
157 format_description!("[year]-[month]-[day]T[hour]:[minute]:[second].[subsecond digits:3]Z");
158 static PARTIAL_RFC3339_6: &[time::format_description::FormatItem<'static>] =
159 format_description!("[year]-[month]-[day]T[hour]:[minute]:[second].[subsecond digits:6]Z");
160 static PARTIAL_RFC3339_9: &[time::format_description::FormatItem<'static>] =
161 format_description!("[year]-[month]-[day]T[hour]:[minute]:[second].[subsecond digits:9]Z");
162
163 let fmt = match ts.nanosecond() {
164 0 => PARTIAL_RFC3339_0,
165 ns if ns % 1_000_000 == 0 => PARTIAL_RFC3339_3,
166 ns if ns % 1_000 == 0 => PARTIAL_RFC3339_6,
167 _ => PARTIAL_RFC3339_9,
168 };
169
170 Ok(ts.format(fmt)?)
171}
172
173fn encode(h: &History) -> Result<Vec<u8>> {
174 use rmp::encode;
175
176 let mut output = vec![];
177 encode::write_array_len(&mut output, 9)?;
179
180 encode::write_str(&mut output, &h.id.0)?;
181 encode::write_str(&mut output, &(format_rfc3339(h.timestamp)?))?;
182 encode::write_sint(&mut output, h.duration)?;
183 encode::write_sint(&mut output, h.exit)?;
184 encode::write_str(&mut output, &h.command)?;
185 encode::write_str(&mut output, &h.cwd)?;
186 encode::write_str(&mut output, &h.session)?;
187 encode::write_str(&mut output, &h.hostname)?;
188 match h.deleted_at {
189 Some(d) => encode::write_str(&mut output, &format_rfc3339(d)?)?,
190 None => encode::write_nil(&mut output)?,
191 }
192
193 Ok(output)
194}
195
196fn decode(bytes: &[u8]) -> Result<History> {
197 use rmp::decode::{self, DecodeStringError};
198
199 let mut bytes = Bytes::new(bytes);
200
201 let nfields = decode::read_array_len(&mut bytes).map_err(error_report)?;
202 if nfields < 8 {
203 bail!("malformed decrypted history")
204 }
205 if nfields > 9 {
206 bail!("cannot decrypt history from a newer version of atuin");
207 }
208
209 let bytes = bytes.remaining_slice();
210 let (id, bytes) = decode::read_str_from_slice(bytes).map_err(error_report)?;
211 let (timestamp, bytes) = decode::read_str_from_slice(bytes).map_err(error_report)?;
212
213 let mut bytes = Bytes::new(bytes);
214 let duration = decode::read_int(&mut bytes).map_err(error_report)?;
215 let exit = decode::read_int(&mut bytes).map_err(error_report)?;
216
217 let bytes = bytes.remaining_slice();
218 let (command, bytes) = decode::read_str_from_slice(bytes).map_err(error_report)?;
219 let (cwd, bytes) = decode::read_str_from_slice(bytes).map_err(error_report)?;
220 let (session, bytes) = decode::read_str_from_slice(bytes).map_err(error_report)?;
221 let (hostname, bytes) = decode::read_str_from_slice(bytes).map_err(error_report)?;
222
223 let mut deleted_at = None;
225 let mut bytes = bytes;
226 if nfields > 8 {
227 bytes = match decode::read_str_from_slice(bytes) {
228 Ok((d, b)) => {
229 deleted_at = Some(d);
230 b
231 }
232 Err(DecodeStringError::TypeMismatch(Marker::Null)) => {
234 let mut c = Bytes::new(bytes);
236 decode::read_nil(&mut c).map_err(error_report)?;
237 c.remaining_slice()
238 }
239 Err(err) => return Err(error_report(err)),
240 };
241 }
242
243 if !bytes.is_empty() {
244 bail!("trailing bytes in encoded history. malformed")
245 }
246
247 Ok(History {
248 id: id.to_owned().into(),
249 timestamp: OffsetDateTime::parse(timestamp, &Rfc3339)?,
250 duration,
251 exit,
252 command: command.to_owned(),
253 cwd: cwd.to_owned(),
254 session: session.to_owned(),
255 hostname: hostname.to_owned(),
256 deleted_at: deleted_at
257 .map(|t| OffsetDateTime::parse(t, &Rfc3339))
258 .transpose()?,
259 })
260}
261
262fn error_report<E: std::fmt::Debug>(err: E) -> eyre::Report {
263 eyre!("{err:?}")
264}
265
#[cfg(test)]
mod test {
    use crypto_secretbox::{aead::OsRng, KeyInit, XSalsa20Poly1305};
    use pretty_assertions::assert_eq;
    use time::{macros::datetime, OffsetDateTime};

    use crate::history::History;

    use super::{decode, decrypt, encode, encrypt};

    /// `sample_history(None)` encoded in the current 9-field layout.
    /// The final `192` (0xC0) is the nil `deleted_at`.
    const ENCODED_NINE_FIELDS: &[u8] = &[
        0x99, 0xD9, 32, 54, 54, 100, 49, 54, 99, 98, 101, 101, 55, 99, 100, 52, 55, 53, 51, 56,
        101, 53, 99, 53, 98, 56, 98, 52, 52, 101, 57, 48, 48, 54, 101, 187, 50, 48, 50, 51, 45,
        48, 53, 45, 50, 56, 84, 49, 56, 58, 51, 53, 58, 52, 48, 46, 54, 51, 51, 56, 55, 50, 90,
        206, 2, 238, 210, 240, 0, 170, 103, 105, 116, 32, 115, 116, 97, 116, 117, 115, 217, 42,
        47, 85, 115, 101, 114, 115, 47, 99, 111, 110, 114, 97, 100, 46, 108, 117, 100, 103, 97,
        116, 101, 47, 68, 111, 99, 117, 109, 101, 110, 116, 115, 47, 99, 111, 100, 101, 47, 97,
        116, 117, 105, 110, 217, 32, 98, 57, 55, 100, 57, 97, 51, 48, 54, 102, 50, 55, 52, 52,
        55, 51, 97, 50, 48, 51, 100, 50, 101, 98, 97, 52, 49, 102, 57, 52, 53, 55, 187, 102,
        118, 102, 103, 57, 51, 54, 99, 48, 107, 112, 102, 58, 99, 111, 110, 114, 97, 100, 46,
        108, 117, 100, 103, 97, 116, 101, 192,
    ];

    /// The same entry in the legacy 8-field layout, which has no `deleted_at`.
    const ENCODED_EIGHT_FIELDS: &[u8] = &[
        0x98, 0xD9, 32, 54, 54, 100, 49, 54, 99, 98, 101, 101, 55, 99, 100, 52, 55, 53, 51, 56,
        101, 53, 99, 53, 98, 56, 98, 52, 52, 101, 57, 48, 48, 54, 101, 187, 50, 48, 50, 51, 45,
        48, 53, 45, 50, 56, 84, 49, 56, 58, 51, 53, 58, 52, 48, 46, 54, 51, 51, 56, 55, 50, 90,
        206, 2, 238, 210, 240, 0, 170, 103, 105, 116, 32, 115, 116, 97, 116, 117, 115, 217, 42,
        47, 85, 115, 101, 114, 115, 47, 99, 111, 110, 114, 97, 100, 46, 108, 117, 100, 103, 97,
        116, 101, 47, 68, 111, 99, 117, 109, 101, 110, 116, 115, 47, 99, 111, 100, 101, 47, 97,
        116, 117, 105, 110, 217, 32, 98, 57, 55, 100, 57, 97, 51, 48, 54, 102, 50, 55, 52, 52,
        55, 51, 97, 50, 48, 51, 100, 50, 101, 98, 97, 52, 49, 102, 57, 52, 53, 55, 187, 102,
        118, 102, 103, 57, 51, 54, 99, 48, 107, 112, 102, 58, 99, 111, 110, 114, 97, 100, 46,
        108, 117, 100, 103, 97, 116, 101,
    ];

    /// The entry described by both encoded fixtures, with the given `deleted_at`.
    fn sample_history(deleted_at: Option<OffsetDateTime>) -> History {
        History {
            id: "66d16cbee7cd47538e5c5b8b44e9006e".to_owned().into(),
            timestamp: datetime!(2023-05-28 18:35:40.633872 +00:00),
            duration: 49206000,
            exit: 0,
            command: "git status".to_owned(),
            cwd: "/Users/conrad.ludgate/Documents/code/atuin".to_owned(),
            session: "b97d9a306f274473a203d2eba41f9457".to_owned(),
            hostname: "fvfg936c0kpf:conrad.ludgate".to_owned(),
            deleted_at,
        }
    }

    #[test]
    fn test_encrypt_decrypt() {
        let key1 = XSalsa20Poly1305::generate_key(&mut OsRng);
        let key2 = XSalsa20Poly1305::generate_key(&mut OsRng);

        let history: History = History::from_db()
            .id("1".into())
            .timestamp(OffsetDateTime::now_utc())
            .command("ls".into())
            .cwd("/home/ellie".into())
            .exit(0)
            .duration(1)
            .session("beep boop".into())
            .hostname("booop".into())
            .deleted_at(None)
            .build()
            .into();

        let e1 = encrypt(&history, &key1).unwrap();
        let e2 = encrypt(&history, &key2).unwrap();

        // Independent keys and nonces must never produce identical output.
        assert_ne!(e1.ciphertext, e2.ciphertext);
        assert_ne!(e1.nonce, e2.nonce);

        let decrypted =
            decrypt(e1, &key1).unwrap_or_else(|e| panic!("failed to decrypt, got {}", e));
        assert_eq!(decrypted, history);

        let _ = decrypt(e2, &key1).expect_err("expected an error decrypting with invalid key");
    }

    #[test]
    fn test_decode() {
        let expected = sample_history(None);

        let decoded = decode(ENCODED_NINE_FIELDS).unwrap();
        assert_eq!(expected, decoded);

        // Re-encoding must reproduce the fixture byte-for-byte.
        let reencoded = encode(&decoded).unwrap();
        assert_eq!(ENCODED_NINE_FIELDS, &*reencoded);
    }

    #[test]
    fn test_decode_deleted() {
        let history = sample_history(Some(datetime!(2023-05-28 18:35:40.633872 +00:00)));

        let encoded = encode(&history).unwrap();
        let roundtrip = decode(&encoded).unwrap();
        assert_eq!(history, roundtrip);
    }

    #[test]
    fn test_decode_old() {
        let decoded = decode(ENCODED_EIGHT_FIELDS).unwrap();
        assert_eq!(sample_history(None), decoded);
    }

    #[test]
    fn key_encodings() {
        use super::{decode_key, encode_key, Key};

        let key = Key::from([
            27, 91, 42, 91, 210, 107, 9, 216, 170, 190, 242, 62, 6, 84, 69, 148, 148, 53, 251, 117,
            226, 167, 173, 52, 82, 34, 138, 110, 169, 124, 92, 229,
        ]);

        assert_eq!(
            encode_key(&key).unwrap(),
            "3AAgG1sqW8zSawnM2MyqzL7M8j4GVEXMlMyUNcz7dczizKfMrTRSIsyKbsypfFzM5Q=="
        );

        // A MessagePack `bin 8` form, and the `array 16` form that `encode_key` emits.
        for encoded in [
            "xCAbWypb0msJ2Kq+8j4GVEWUlDX7deKnrTRSIopuqXxc5Q==",
            "3AAgG1sqW8zSawnM2MyqzL7M8j4GVEXMlMyUNcz7dczizKfMrTRSIsyKbsypfFzM5Q==",
        ] {
            assert_eq!(decode_key(encoded.to_owned()).expect(encoded), key);
        }
    }
}