atuin_client/record/
sqlite_store.rs

1// Here we are using sqlite as a pretty dumb store, and will not be running any complex queries.
2// Multiple stores of multiple types are all stored in one chonky table (for now), and we just index
3// by tag/host
4
5use std::str::FromStr;
6use std::{path::Path, time::Duration};
7
8use async_trait::async_trait;
9use eyre::{eyre, Result};
10use fs_err as fs;
11
12use sqlx::{
13    sqlite::{SqliteConnectOptions, SqliteJournalMode, SqlitePool, SqlitePoolOptions, SqliteRow},
14    Row,
15};
16
17use atuin_common::record::{
18    EncryptedData, Host, HostId, Record, RecordId, RecordIdx, RecordStatus,
19};
20use uuid::Uuid;
21
22use super::encryption::PASETO_V4;
23use super::store::Store;
24
25#[derive(Debug, Clone)]
26pub struct SqliteStore {
27    pool: SqlitePool,
28}
29
30impl SqliteStore {
31    pub async fn new(path: impl AsRef<Path>, timeout: f64) -> Result<Self> {
32        let path = path.as_ref();
33
34        debug!("opening sqlite database at {:?}", path);
35
36        let create = !path.exists();
37        if create {
38            if let Some(dir) = path.parent() {
39                fs::create_dir_all(dir)?;
40            }
41        }
42
43        let opts = SqliteConnectOptions::from_str(path.as_os_str().to_str().unwrap())?
44            .journal_mode(SqliteJournalMode::Wal)
45            .foreign_keys(true)
46            .create_if_missing(true);
47
48        let pool = SqlitePoolOptions::new()
49            .acquire_timeout(Duration::from_secs_f64(timeout))
50            .connect_with(opts)
51            .await?;
52
53        Self::setup_db(&pool).await?;
54
55        Ok(Self { pool })
56    }
57
58    async fn setup_db(pool: &SqlitePool) -> Result<()> {
59        debug!("running sqlite database setup");
60
61        sqlx::migrate!("./record-migrations").run(pool).await?;
62
63        Ok(())
64    }
65
66    async fn save_raw(
67        tx: &mut sqlx::Transaction<'_, sqlx::Sqlite>,
68        r: &Record<EncryptedData>,
69    ) -> Result<()> {
70        // In sqlite, we are "limited" to i64. But that is still fine, until 2262.
71        sqlx::query(
72            "insert or ignore into store(id, idx, host, tag, timestamp, version, data, cek)
73                values(?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)",
74        )
75        .bind(r.id.0.as_hyphenated().to_string())
76        .bind(r.idx as i64)
77        .bind(r.host.id.0.as_hyphenated().to_string())
78        .bind(r.tag.as_str())
79        .bind(r.timestamp as i64)
80        .bind(r.version.as_str())
81        .bind(r.data.data.as_str())
82        .bind(r.data.content_encryption_key.as_str())
83        .execute(&mut **tx)
84        .await?;
85
86        Ok(())
87    }
88
89    fn query_row(row: SqliteRow) -> Record<EncryptedData> {
90        let idx: i64 = row.get("idx");
91        let timestamp: i64 = row.get("timestamp");
92
93        // tbh at this point things are pretty fucked so just panic
94        let id = Uuid::from_str(row.get("id")).expect("invalid id UUID format in sqlite DB");
95        let host = Uuid::from_str(row.get("host")).expect("invalid host UUID format in sqlite DB");
96
97        Record {
98            id: RecordId(id),
99            idx: idx as u64,
100            host: Host::new(HostId(host)),
101            timestamp: timestamp as u64,
102            tag: row.get("tag"),
103            version: row.get("version"),
104            data: EncryptedData {
105                data: row.get("data"),
106                content_encryption_key: row.get("cek"),
107            },
108        }
109    }
110
111    async fn load_all(&self) -> Result<Vec<Record<EncryptedData>>> {
112        let res = sqlx::query("select * from store ")
113            .map(Self::query_row)
114            .fetch_all(&self.pool)
115            .await?;
116
117        Ok(res)
118    }
119}
120
121#[async_trait]
122impl Store for SqliteStore {
123    async fn push_batch(
124        &self,
125        records: impl Iterator<Item = &Record<EncryptedData>> + Send + Sync,
126    ) -> Result<()> {
127        let mut tx = self.pool.begin().await?;
128
129        for record in records {
130            Self::save_raw(&mut tx, record).await?;
131        }
132
133        tx.commit().await?;
134
135        Ok(())
136    }
137
138    async fn get(&self, id: RecordId) -> Result<Record<EncryptedData>> {
139        let res = sqlx::query("select * from store where store.id = ?1")
140            .bind(id.0.as_hyphenated().to_string())
141            .map(Self::query_row)
142            .fetch_one(&self.pool)
143            .await?;
144
145        Ok(res)
146    }
147
148    async fn delete(&self, id: RecordId) -> Result<()> {
149        sqlx::query("delete from store where id = ?1")
150            .bind(id.0.as_hyphenated().to_string())
151            .execute(&self.pool)
152            .await?;
153
154        Ok(())
155    }
156
157    async fn delete_all(&self) -> Result<()> {
158        sqlx::query("delete from store").execute(&self.pool).await?;
159
160        Ok(())
161    }
162
163    async fn last(&self, host: HostId, tag: &str) -> Result<Option<Record<EncryptedData>>> {
164        let res =
165            sqlx::query("select * from store where host=?1 and tag=?2 order by idx desc limit 1")
166                .bind(host.0.as_hyphenated().to_string())
167                .bind(tag)
168                .map(Self::query_row)
169                .fetch_one(&self.pool)
170                .await;
171
172        match res {
173            Err(sqlx::Error::RowNotFound) => Ok(None),
174            Err(e) => Err(eyre!("an error occurred: {}", e)),
175            Ok(record) => Ok(Some(record)),
176        }
177    }
178
179    async fn first(&self, host: HostId, tag: &str) -> Result<Option<Record<EncryptedData>>> {
180        self.idx(host, tag, 0).await
181    }
182
183    async fn len_all(&self) -> Result<u64> {
184        let res: Result<(i64,), sqlx::Error> = sqlx::query_as("select count(*) from store")
185            .fetch_one(&self.pool)
186            .await;
187        match res {
188            Err(e) => Err(eyre!("failed to fetch local store len: {}", e)),
189            Ok(v) => Ok(v.0 as u64),
190        }
191    }
192
193    async fn len_tag(&self, tag: &str) -> Result<u64> {
194        let res: Result<(i64,), sqlx::Error> =
195            sqlx::query_as("select count(*) from store where tag=?1")
196                .bind(tag)
197                .fetch_one(&self.pool)
198                .await;
199        match res {
200            Err(e) => Err(eyre!("failed to fetch local store len: {}", e)),
201            Ok(v) => Ok(v.0 as u64),
202        }
203    }
204
205    async fn len(&self, host: HostId, tag: &str) -> Result<u64> {
206        let last = self.last(host, tag).await?;
207
208        if let Some(last) = last {
209            return Ok(last.idx + 1);
210        }
211
212        return Ok(0);
213    }
214
215    async fn next(
216        &self,
217        host: HostId,
218        tag: &str,
219        idx: RecordIdx,
220        limit: u64,
221    ) -> Result<Vec<Record<EncryptedData>>> {
222        let res = sqlx::query(
223            "select * from store where idx >= ?1 and host = ?2 and tag = ?3 order by idx asc limit ?4",
224        )
225        .bind(idx as i64)
226        .bind(host.0.as_hyphenated().to_string())
227        .bind(tag)
228        .bind(limit as i64)
229        .map(Self::query_row)
230        .fetch_all(&self.pool)
231        .await?;
232
233        Ok(res)
234    }
235
236    async fn idx(
237        &self,
238        host: HostId,
239        tag: &str,
240        idx: RecordIdx,
241    ) -> Result<Option<Record<EncryptedData>>> {
242        let res = sqlx::query("select * from store where idx = ?1 and host = ?2 and tag = ?3")
243            .bind(idx as i64)
244            .bind(host.0.as_hyphenated().to_string())
245            .bind(tag)
246            .map(Self::query_row)
247            .fetch_one(&self.pool)
248            .await;
249
250        match res {
251            Err(sqlx::Error::RowNotFound) => Ok(None),
252            Err(e) => Err(eyre!("an error occurred: {}", e)),
253            Ok(v) => Ok(Some(v)),
254        }
255    }
256
257    async fn status(&self) -> Result<RecordStatus> {
258        let mut status = RecordStatus::new();
259
260        let res: Result<Vec<(String, String, i64)>, sqlx::Error> =
261            sqlx::query_as("select host, tag, max(idx) from store group by host, tag")
262                .fetch_all(&self.pool)
263                .await;
264
265        let res = match res {
266            Err(e) => return Err(eyre!("failed to fetch local store status: {}", e)),
267            Ok(v) => v,
268        };
269
270        for i in res {
271            let host = HostId(
272                Uuid::from_str(i.0.as_str()).expect("failed to parse uuid for local store status"),
273            );
274
275            status.set_raw(host, i.1, i.2 as u64);
276        }
277
278        Ok(status)
279    }
280
281    async fn all_tagged(&self, tag: &str) -> Result<Vec<Record<EncryptedData>>> {
282        let res = sqlx::query("select * from store where tag = ?1 order by timestamp asc")
283            .bind(tag)
284            .map(Self::query_row)
285            .fetch_all(&self.pool)
286            .await?;
287
288        Ok(res)
289    }
290
291    /// Reencrypt every single item in this store with a new key
292    /// Be careful - this may mess with sync.
293    async fn re_encrypt(&self, old_key: &[u8; 32], new_key: &[u8; 32]) -> Result<()> {
294        // Load all the records
295        // In memory like some of the other code here
296        // This will never be called in a hot loop, and only under the following circumstances
297        // 1. The user has logged into a new account, with a new key. They are unlikely to have a
298        //    lot of data
299        // 2. The user has encountered some sort of issue, and runs a maintenance command that
300        //    invokes this
301        let all = self.load_all().await?;
302
303        let re_encrypted = all
304            .into_iter()
305            .map(|record| record.re_encrypt::<PASETO_V4>(old_key, new_key))
306            .collect::<Result<Vec<_>>>()?;
307
308        // next up, we delete all the old data and reinsert the new stuff
309        // do it in one transaction, so if anything fails we rollback OK
310
311        let mut tx = self.pool.begin().await?;
312
313        let res = sqlx::query("delete from store").execute(&mut *tx).await?;
314
315        let rows = res.rows_affected();
316        debug!("deleted {rows} rows");
317
318        // don't call push_batch, as it will start its own transaction
319        // call the underlying save_raw
320
321        for record in re_encrypted {
322            Self::save_raw(&mut tx, &record).await?;
323        }
324
325        tx.commit().await?;
326
327        Ok(())
328    }
329
330    /// Verify that every record in this store can be decrypted with the current key
331    /// Someday maybe also check each tag/record can be deserialized, but not for now.
332    async fn verify(&self, key: &[u8; 32]) -> Result<()> {
333        let all = self.load_all().await?;
334
335        all.into_iter()
336            .map(|record| record.decrypt::<PASETO_V4>(key))
337            .collect::<Result<Vec<_>>>()?;
338
339        Ok(())
340    }
341
342    /// Verify that every record in this store can be decrypted with the current key
343    /// Someday maybe also check each tag/record can be deserialized, but not for now.
344    async fn purge(&self, key: &[u8; 32]) -> Result<()> {
345        let all = self.load_all().await?;
346
347        for record in all.iter() {
348            match record.clone().decrypt::<PASETO_V4>(key) {
349                Ok(_) => continue,
350                Err(_) => {
351                    println!(
352                        "Failed to decrypt {}, deleting",
353                        record.id.0.as_hyphenated()
354                    );
355
356                    self.delete(record.id).await?;
357                }
358            }
359        }
360
361        Ok(())
362    }
363}
364
365#[cfg(test)]
366mod tests {
367    use atuin_common::{
368        record::{DecryptedData, EncryptedData, Host, HostId, Record},
369        utils::uuid_v7,
370    };
371
372    use crate::{
373        encryption::generate_encoded_key,
374        record::{encryption::PASETO_V4, store::Store},
375        settings::test_local_timeout,
376    };
377
378    use super::SqliteStore;
379
380    fn test_record() -> Record<EncryptedData> {
381        Record::builder()
382            .host(Host::new(HostId(atuin_common::utils::uuid_v7())))
383            .version("v1".into())
384            .tag(atuin_common::utils::uuid_v7().simple().to_string())
385            .data(EncryptedData {
386                data: "1234".into(),
387                content_encryption_key: "1234".into(),
388            })
389            .idx(0)
390            .build()
391    }
392
393    #[tokio::test]
394    async fn create_db() {
395        let db = SqliteStore::new(":memory:", test_local_timeout()).await;
396
397        assert!(
398            db.is_ok(),
399            "db could not be created, {:?}",
400            db.err().unwrap()
401        );
402    }
403
404    #[tokio::test]
405    async fn push_record() {
406        let db = SqliteStore::new(":memory:", test_local_timeout())
407            .await
408            .unwrap();
409        let record = test_record();
410
411        db.push(&record).await.expect("failed to insert record");
412    }
413
414    #[tokio::test]
415    async fn get_record() {
416        let db = SqliteStore::new(":memory:", test_local_timeout())
417            .await
418            .unwrap();
419        let record = test_record();
420        db.push(&record).await.unwrap();
421
422        let new_record = db.get(record.id).await.expect("failed to fetch record");
423
424        assert_eq!(record, new_record, "records are not equal");
425    }
426
427    #[tokio::test]
428    async fn last() {
429        let db = SqliteStore::new(":memory:", test_local_timeout())
430            .await
431            .unwrap();
432        let record = test_record();
433        db.push(&record).await.unwrap();
434
435        let last = db
436            .last(record.host.id, record.tag.as_str())
437            .await
438            .expect("failed to get store len");
439
440        assert_eq!(
441            last.unwrap().id,
442            record.id,
443            "expected to get back the same record that was inserted"
444        );
445    }
446
447    #[tokio::test]
448    async fn first() {
449        let db = SqliteStore::new(":memory:", test_local_timeout())
450            .await
451            .unwrap();
452        let record = test_record();
453        db.push(&record).await.unwrap();
454
455        let first = db
456            .first(record.host.id, record.tag.as_str())
457            .await
458            .expect("failed to get store len");
459
460        assert_eq!(
461            first.unwrap().id,
462            record.id,
463            "expected to get back the same record that was inserted"
464        );
465    }
466
467    #[tokio::test]
468    async fn len() {
469        let db = SqliteStore::new(":memory:", test_local_timeout())
470            .await
471            .unwrap();
472        let record = test_record();
473        db.push(&record).await.unwrap();
474
475        let len = db
476            .len(record.host.id, record.tag.as_str())
477            .await
478            .expect("failed to get store len");
479
480        assert_eq!(len, 1, "expected length of 1 after insert");
481    }
482
483    #[tokio::test]
484    async fn len_tag() {
485        let db = SqliteStore::new(":memory:", test_local_timeout())
486            .await
487            .unwrap();
488        let record = test_record();
489        db.push(&record).await.unwrap();
490
491        let len = db
492            .len_tag(record.tag.as_str())
493            .await
494            .expect("failed to get store len");
495
496        assert_eq!(len, 1, "expected length of 1 after insert");
497    }
498
499    #[tokio::test]
500    async fn len_different_tags() {
501        let db = SqliteStore::new(":memory:", test_local_timeout())
502            .await
503            .unwrap();
504
505        // these have different tags, so the len should be the same
506        // we model multiple stores within one database
507        // new store = new tag = independent length
508        let first = test_record();
509        let second = test_record();
510
511        db.push(&first).await.unwrap();
512        db.push(&second).await.unwrap();
513
514        let first_len = db.len(first.host.id, first.tag.as_str()).await.unwrap();
515        let second_len = db.len(second.host.id, second.tag.as_str()).await.unwrap();
516
517        assert_eq!(first_len, 1, "expected length of 1 after insert");
518        assert_eq!(second_len, 1, "expected length of 1 after insert");
519    }
520
521    #[tokio::test]
522    async fn append_a_bunch() {
523        let db = SqliteStore::new(":memory:", test_local_timeout())
524            .await
525            .unwrap();
526
527        let mut tail = test_record();
528        db.push(&tail).await.expect("failed to push record");
529
530        for _ in 1..100 {
531            tail = tail.append(vec![1, 2, 3, 4]).encrypt::<PASETO_V4>(&[0; 32]);
532            db.push(&tail).await.unwrap();
533        }
534
535        assert_eq!(
536            db.len(tail.host.id, tail.tag.as_str()).await.unwrap(),
537            100,
538            "failed to insert 100 records"
539        );
540
541        assert_eq!(
542            db.len_tag(tail.tag.as_str()).await.unwrap(),
543            100,
544            "failed to insert 100 records"
545        );
546    }
547
548    #[tokio::test]
549    async fn append_a_big_bunch() {
550        let db = SqliteStore::new(":memory:", test_local_timeout())
551            .await
552            .unwrap();
553
554        let mut records: Vec<Record<EncryptedData>> = Vec::with_capacity(10000);
555
556        let mut tail = test_record();
557        records.push(tail.clone());
558
559        for _ in 1..10000 {
560            tail = tail.append(vec![1, 2, 3]).encrypt::<PASETO_V4>(&[0; 32]);
561            records.push(tail.clone());
562        }
563
564        db.push_batch(records.iter()).await.unwrap();
565
566        assert_eq!(
567            db.len(tail.host.id, tail.tag.as_str()).await.unwrap(),
568            10000,
569            "failed to insert 10k records"
570        );
571    }
572
573    #[tokio::test]
574    async fn re_encrypt() {
575        let store = SqliteStore::new(":memory:", test_local_timeout())
576            .await
577            .unwrap();
578        let (key, _) = generate_encoded_key().unwrap();
579        let data = vec![0u8, 1u8, 2u8, 3u8];
580        let host_id = HostId(uuid_v7());
581
582        for i in 0..10 {
583            let record = Record::builder()
584                .host(Host::new(host_id))
585                .version(String::from("test"))
586                .tag(String::from("test"))
587                .idx(i)
588                .data(DecryptedData(data.clone()))
589                .build();
590
591            let record = record.encrypt::<PASETO_V4>(&key.into());
592            store
593                .push(&record)
594                .await
595                .expect("failed to push encrypted record");
596        }
597
598        // first, check that we can decrypt the data with the current key
599        let all = store.all_tagged("test").await.unwrap();
600
601        assert_eq!(all.len(), 10, "failed to fetch all records");
602
603        for record in all {
604            let decrypted = record.decrypt::<PASETO_V4>(&key.into()).unwrap();
605            assert_eq!(decrypted.data.0, data);
606        }
607
608        // reencrypt the store, then check if
609        // 1) it cannot be decrypted with the old key
610        // 2) it can be decrypted with the new key
611
612        let (new_key, _) = generate_encoded_key().unwrap();
613        store
614            .re_encrypt(&key.into(), &new_key.into())
615            .await
616            .expect("failed to re-encrypt store");
617
618        let all = store.all_tagged("test").await.unwrap();
619
620        for record in all.iter() {
621            let decrypted = record.clone().decrypt::<PASETO_V4>(&key.into());
622            assert!(
623                decrypted.is_err(),
624                "did not get error decrypting with old key after re-encrypt"
625            )
626        }
627
628        for record in all {
629            let decrypted = record.decrypt::<PASETO_V4>(&new_key.into()).unwrap();
630            assert_eq!(decrypted.data.0, data);
631        }
632
633        assert_eq!(store.len(host_id, "test").await.unwrap(), 10);
634    }
635}