1use std::str::FromStr;
6use std::{path::Path, time::Duration};
7
8use async_trait::async_trait;
9use eyre::{eyre, Result};
10use fs_err as fs;
11
12use sqlx::{
13 sqlite::{SqliteConnectOptions, SqliteJournalMode, SqlitePool, SqlitePoolOptions, SqliteRow},
14 Row,
15};
16
17use atuin_common::record::{
18 EncryptedData, Host, HostId, Record, RecordId, RecordIdx, RecordStatus,
19};
20use uuid::Uuid;
21
22use super::encryption::PASETO_V4;
23use super::store::Store;
24
25#[derive(Debug, Clone)]
26pub struct SqliteStore {
27 pool: SqlitePool,
28}
29
30impl SqliteStore {
31 pub async fn new(path: impl AsRef<Path>, timeout: f64) -> Result<Self> {
32 let path = path.as_ref();
33
34 debug!("opening sqlite database at {:?}", path);
35
36 let create = !path.exists();
37 if create {
38 if let Some(dir) = path.parent() {
39 fs::create_dir_all(dir)?;
40 }
41 }
42
43 let opts = SqliteConnectOptions::from_str(path.as_os_str().to_str().unwrap())?
44 .journal_mode(SqliteJournalMode::Wal)
45 .foreign_keys(true)
46 .create_if_missing(true);
47
48 let pool = SqlitePoolOptions::new()
49 .acquire_timeout(Duration::from_secs_f64(timeout))
50 .connect_with(opts)
51 .await?;
52
53 Self::setup_db(&pool).await?;
54
55 Ok(Self { pool })
56 }
57
58 async fn setup_db(pool: &SqlitePool) -> Result<()> {
59 debug!("running sqlite database setup");
60
61 sqlx::migrate!("./record-migrations").run(pool).await?;
62
63 Ok(())
64 }
65
66 async fn save_raw(
67 tx: &mut sqlx::Transaction<'_, sqlx::Sqlite>,
68 r: &Record<EncryptedData>,
69 ) -> Result<()> {
70 sqlx::query(
72 "insert or ignore into store(id, idx, host, tag, timestamp, version, data, cek)
73 values(?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)",
74 )
75 .bind(r.id.0.as_hyphenated().to_string())
76 .bind(r.idx as i64)
77 .bind(r.host.id.0.as_hyphenated().to_string())
78 .bind(r.tag.as_str())
79 .bind(r.timestamp as i64)
80 .bind(r.version.as_str())
81 .bind(r.data.data.as_str())
82 .bind(r.data.content_encryption_key.as_str())
83 .execute(&mut **tx)
84 .await?;
85
86 Ok(())
87 }
88
89 fn query_row(row: SqliteRow) -> Record<EncryptedData> {
90 let idx: i64 = row.get("idx");
91 let timestamp: i64 = row.get("timestamp");
92
93 let id = Uuid::from_str(row.get("id")).expect("invalid id UUID format in sqlite DB");
95 let host = Uuid::from_str(row.get("host")).expect("invalid host UUID format in sqlite DB");
96
97 Record {
98 id: RecordId(id),
99 idx: idx as u64,
100 host: Host::new(HostId(host)),
101 timestamp: timestamp as u64,
102 tag: row.get("tag"),
103 version: row.get("version"),
104 data: EncryptedData {
105 data: row.get("data"),
106 content_encryption_key: row.get("cek"),
107 },
108 }
109 }
110
111 async fn load_all(&self) -> Result<Vec<Record<EncryptedData>>> {
112 let res = sqlx::query("select * from store ")
113 .map(Self::query_row)
114 .fetch_all(&self.pool)
115 .await?;
116
117 Ok(res)
118 }
119}
120
121#[async_trait]
122impl Store for SqliteStore {
123 async fn push_batch(
124 &self,
125 records: impl Iterator<Item = &Record<EncryptedData>> + Send + Sync,
126 ) -> Result<()> {
127 let mut tx = self.pool.begin().await?;
128
129 for record in records {
130 Self::save_raw(&mut tx, record).await?;
131 }
132
133 tx.commit().await?;
134
135 Ok(())
136 }
137
138 async fn get(&self, id: RecordId) -> Result<Record<EncryptedData>> {
139 let res = sqlx::query("select * from store where store.id = ?1")
140 .bind(id.0.as_hyphenated().to_string())
141 .map(Self::query_row)
142 .fetch_one(&self.pool)
143 .await?;
144
145 Ok(res)
146 }
147
148 async fn delete(&self, id: RecordId) -> Result<()> {
149 sqlx::query("delete from store where id = ?1")
150 .bind(id.0.as_hyphenated().to_string())
151 .execute(&self.pool)
152 .await?;
153
154 Ok(())
155 }
156
157 async fn delete_all(&self) -> Result<()> {
158 sqlx::query("delete from store").execute(&self.pool).await?;
159
160 Ok(())
161 }
162
163 async fn last(&self, host: HostId, tag: &str) -> Result<Option<Record<EncryptedData>>> {
164 let res =
165 sqlx::query("select * from store where host=?1 and tag=?2 order by idx desc limit 1")
166 .bind(host.0.as_hyphenated().to_string())
167 .bind(tag)
168 .map(Self::query_row)
169 .fetch_one(&self.pool)
170 .await;
171
172 match res {
173 Err(sqlx::Error::RowNotFound) => Ok(None),
174 Err(e) => Err(eyre!("an error occurred: {}", e)),
175 Ok(record) => Ok(Some(record)),
176 }
177 }
178
179 async fn first(&self, host: HostId, tag: &str) -> Result<Option<Record<EncryptedData>>> {
180 self.idx(host, tag, 0).await
181 }
182
183 async fn len_all(&self) -> Result<u64> {
184 let res: Result<(i64,), sqlx::Error> = sqlx::query_as("select count(*) from store")
185 .fetch_one(&self.pool)
186 .await;
187 match res {
188 Err(e) => Err(eyre!("failed to fetch local store len: {}", e)),
189 Ok(v) => Ok(v.0 as u64),
190 }
191 }
192
193 async fn len_tag(&self, tag: &str) -> Result<u64> {
194 let res: Result<(i64,), sqlx::Error> =
195 sqlx::query_as("select count(*) from store where tag=?1")
196 .bind(tag)
197 .fetch_one(&self.pool)
198 .await;
199 match res {
200 Err(e) => Err(eyre!("failed to fetch local store len: {}", e)),
201 Ok(v) => Ok(v.0 as u64),
202 }
203 }
204
205 async fn len(&self, host: HostId, tag: &str) -> Result<u64> {
206 let last = self.last(host, tag).await?;
207
208 if let Some(last) = last {
209 return Ok(last.idx + 1);
210 }
211
212 return Ok(0);
213 }
214
215 async fn next(
216 &self,
217 host: HostId,
218 tag: &str,
219 idx: RecordIdx,
220 limit: u64,
221 ) -> Result<Vec<Record<EncryptedData>>> {
222 let res = sqlx::query(
223 "select * from store where idx >= ?1 and host = ?2 and tag = ?3 order by idx asc limit ?4",
224 )
225 .bind(idx as i64)
226 .bind(host.0.as_hyphenated().to_string())
227 .bind(tag)
228 .bind(limit as i64)
229 .map(Self::query_row)
230 .fetch_all(&self.pool)
231 .await?;
232
233 Ok(res)
234 }
235
236 async fn idx(
237 &self,
238 host: HostId,
239 tag: &str,
240 idx: RecordIdx,
241 ) -> Result<Option<Record<EncryptedData>>> {
242 let res = sqlx::query("select * from store where idx = ?1 and host = ?2 and tag = ?3")
243 .bind(idx as i64)
244 .bind(host.0.as_hyphenated().to_string())
245 .bind(tag)
246 .map(Self::query_row)
247 .fetch_one(&self.pool)
248 .await;
249
250 match res {
251 Err(sqlx::Error::RowNotFound) => Ok(None),
252 Err(e) => Err(eyre!("an error occurred: {}", e)),
253 Ok(v) => Ok(Some(v)),
254 }
255 }
256
257 async fn status(&self) -> Result<RecordStatus> {
258 let mut status = RecordStatus::new();
259
260 let res: Result<Vec<(String, String, i64)>, sqlx::Error> =
261 sqlx::query_as("select host, tag, max(idx) from store group by host, tag")
262 .fetch_all(&self.pool)
263 .await;
264
265 let res = match res {
266 Err(e) => return Err(eyre!("failed to fetch local store status: {}", e)),
267 Ok(v) => v,
268 };
269
270 for i in res {
271 let host = HostId(
272 Uuid::from_str(i.0.as_str()).expect("failed to parse uuid for local store status"),
273 );
274
275 status.set_raw(host, i.1, i.2 as u64);
276 }
277
278 Ok(status)
279 }
280
281 async fn all_tagged(&self, tag: &str) -> Result<Vec<Record<EncryptedData>>> {
282 let res = sqlx::query("select * from store where tag = ?1 order by timestamp asc")
283 .bind(tag)
284 .map(Self::query_row)
285 .fetch_all(&self.pool)
286 .await?;
287
288 Ok(res)
289 }
290
291 async fn re_encrypt(&self, old_key: &[u8; 32], new_key: &[u8; 32]) -> Result<()> {
294 let all = self.load_all().await?;
302
303 let re_encrypted = all
304 .into_iter()
305 .map(|record| record.re_encrypt::<PASETO_V4>(old_key, new_key))
306 .collect::<Result<Vec<_>>>()?;
307
308 let mut tx = self.pool.begin().await?;
312
313 let res = sqlx::query("delete from store").execute(&mut *tx).await?;
314
315 let rows = res.rows_affected();
316 debug!("deleted {rows} rows");
317
318 for record in re_encrypted {
322 Self::save_raw(&mut tx, &record).await?;
323 }
324
325 tx.commit().await?;
326
327 Ok(())
328 }
329
330 async fn verify(&self, key: &[u8; 32]) -> Result<()> {
333 let all = self.load_all().await?;
334
335 all.into_iter()
336 .map(|record| record.decrypt::<PASETO_V4>(key))
337 .collect::<Result<Vec<_>>>()?;
338
339 Ok(())
340 }
341
342 async fn purge(&self, key: &[u8; 32]) -> Result<()> {
345 let all = self.load_all().await?;
346
347 for record in all.iter() {
348 match record.clone().decrypt::<PASETO_V4>(key) {
349 Ok(_) => continue,
350 Err(_) => {
351 println!(
352 "Failed to decrypt {}, deleting",
353 record.id.0.as_hyphenated()
354 );
355
356 self.delete(record.id).await?;
357 }
358 }
359 }
360
361 Ok(())
362 }
363}
364
#[cfg(test)]
mod tests {
    use atuin_common::{
        record::{DecryptedData, EncryptedData, Host, HostId, Record},
        utils::uuid_v7,
    };

    use crate::{
        encryption::generate_encoded_key,
        record::{encryption::PASETO_V4, store::Store},
        settings::test_local_timeout,
    };

    use super::SqliteStore;

    /// A standalone record with a fresh host, a unique tag and index 0. Its payload is a
    /// placeholder and cannot be decrypted.
    fn test_record() -> Record<EncryptedData> {
        Record::builder()
            .host(Host::new(HostId(atuin_common::utils::uuid_v7())))
            .version("v1".into())
            .tag(atuin_common::utils::uuid_v7().simple().to_string())
            .data(EncryptedData {
                data: "1234".into(),
                content_encryption_key: "1234".into(),
            })
            .idx(0)
            .build()
    }

    #[tokio::test]
    async fn create_db() {
        let db = SqliteStore::new(":memory:", test_local_timeout()).await;

        assert!(
            db.is_ok(),
            "db could not be created, {:?}",
            db.err().unwrap()
        );
    }

    #[tokio::test]
    async fn push_record() {
        let db = SqliteStore::new(":memory:", test_local_timeout())
            .await
            .unwrap();
        let record = test_record();

        db.push(&record).await.expect("failed to insert record");
    }

    #[tokio::test]
    async fn get_record() {
        let db = SqliteStore::new(":memory:", test_local_timeout())
            .await
            .unwrap();
        let record = test_record();
        db.push(&record).await.unwrap();

        let new_record = db.get(record.id).await.expect("failed to fetch record");

        assert_eq!(record, new_record, "records are not equal");
    }

    #[tokio::test]
    async fn last() {
        let db = SqliteStore::new(":memory:", test_local_timeout())
            .await
            .unwrap();
        let record = test_record();
        db.push(&record).await.unwrap();

        let last = db
            .last(record.host.id, record.tag.as_str())
            .await
            .expect("failed to get last record");

        assert_eq!(
            last.unwrap().id,
            record.id,
            "expected to get back the same record that was inserted"
        );
    }

    #[tokio::test]
    async fn first() {
        let db = SqliteStore::new(":memory:", test_local_timeout())
            .await
            .unwrap();
        let record = test_record();
        db.push(&record).await.unwrap();

        let first = db
            .first(record.host.id, record.tag.as_str())
            .await
            .expect("failed to get first record");

        assert_eq!(
            first.unwrap().id,
            record.id,
            "expected to get back the same record that was inserted"
        );
    }

    /// Lookups on an empty store must yield `None` / 0 rather than an error.
    #[tokio::test]
    async fn empty_store_lookups() {
        let db = SqliteStore::new(":memory:", test_local_timeout())
            .await
            .unwrap();
        let record = test_record();

        let last = db
            .last(record.host.id, record.tag.as_str())
            .await
            .expect("failed to get last record");
        assert!(last.is_none(), "expected no last record in an empty store");

        let first = db
            .first(record.host.id, record.tag.as_str())
            .await
            .expect("failed to get first record");
        assert!(first.is_none(), "expected no first record in an empty store");

        let len = db
            .len(record.host.id, record.tag.as_str())
            .await
            .expect("failed to get store len");
        assert_eq!(len, 0, "expected length of 0 for an empty store");
    }

    #[tokio::test]
    async fn len() {
        let db = SqliteStore::new(":memory:", test_local_timeout())
            .await
            .unwrap();
        let record = test_record();
        db.push(&record).await.unwrap();

        let len = db
            .len(record.host.id, record.tag.as_str())
            .await
            .expect("failed to get store len");

        assert_eq!(len, 1, "expected length of 1 after insert");
    }

    #[tokio::test]
    async fn len_tag() {
        let db = SqliteStore::new(":memory:", test_local_timeout())
            .await
            .unwrap();
        let record = test_record();
        db.push(&record).await.unwrap();

        let len = db
            .len_tag(record.tag.as_str())
            .await
            .expect("failed to get store len");

        assert_eq!(len, 1, "expected length of 1 after insert");
    }

    #[tokio::test]
    async fn len_different_tags() {
        let db = SqliteStore::new(":memory:", test_local_timeout())
            .await
            .unwrap();

        let first = test_record();
        let second = test_record();

        db.push(&first).await.unwrap();
        db.push(&second).await.unwrap();

        let first_len = db.len(first.host.id, first.tag.as_str()).await.unwrap();
        let second_len = db.len(second.host.id, second.tag.as_str()).await.unwrap();

        assert_eq!(first_len, 1, "expected length of 1 after insert");
        assert_eq!(second_len, 1, "expected length of 1 after insert");
    }

    #[tokio::test]
    async fn append_a_bunch() {
        let db = SqliteStore::new(":memory:", test_local_timeout())
            .await
            .unwrap();

        let mut tail = test_record();
        db.push(&tail).await.expect("failed to push record");

        for _ in 1..100 {
            tail = tail.append(vec![1, 2, 3, 4]).encrypt::<PASETO_V4>(&[0; 32]);
            db.push(&tail).await.unwrap();
        }

        assert_eq!(
            db.len(tail.host.id, tail.tag.as_str()).await.unwrap(),
            100,
            "failed to insert 100 records"
        );

        assert_eq!(
            db.len_tag(tail.tag.as_str()).await.unwrap(),
            100,
            "failed to insert 100 records"
        );
    }

    #[tokio::test]
    async fn append_a_big_bunch() {
        let db = SqliteStore::new(":memory:", test_local_timeout())
            .await
            .unwrap();

        let mut records: Vec<Record<EncryptedData>> = Vec::with_capacity(10000);

        let mut tail = test_record();
        records.push(tail.clone());

        for _ in 1..10000 {
            tail = tail.append(vec![1, 2, 3]).encrypt::<PASETO_V4>(&[0; 32]);
            records.push(tail.clone());
        }

        db.push_batch(records.iter()).await.unwrap();

        assert_eq!(
            db.len(tail.host.id, tail.tag.as_str()).await.unwrap(),
            10000,
            "failed to insert 10k records"
        );
    }

    #[tokio::test]
    async fn re_encrypt() {
        let store = SqliteStore::new(":memory:", test_local_timeout())
            .await
            .unwrap();
        let (key, _) = generate_encoded_key().unwrap();
        let data = vec![0u8, 1u8, 2u8, 3u8];
        let host_id = HostId(uuid_v7());

        for i in 0..10 {
            let record = Record::builder()
                .host(Host::new(host_id))
                .version(String::from("test"))
                .tag(String::from("test"))
                .idx(i)
                .data(DecryptedData(data.clone()))
                .build();

            let record = record.encrypt::<PASETO_V4>(&key.into());
            store
                .push(&record)
                .await
                .expect("failed to push encrypted record");
        }

        let all = store.all_tagged("test").await.unwrap();

        assert_eq!(all.len(), 10, "failed to fetch all records");

        for record in all {
            let decrypted = record.decrypt::<PASETO_V4>(&key.into()).unwrap();
            assert_eq!(decrypted.data.0, data);
        }

        let (new_key, _) = generate_encoded_key().unwrap();
        store
            .re_encrypt(&key.into(), &new_key.into())
            .await
            .expect("failed to re-encrypt store");

        let all = store.all_tagged("test").await.unwrap();

        for record in all.iter() {
            let decrypted = record.clone().decrypt::<PASETO_V4>(&key.into());
            assert!(
                decrypted.is_err(),
                "did not get error decrypting with old key after re-encrypt"
            )
        }

        for record in all {
            let decrypted = record.decrypt::<PASETO_V4>(&new_key.into()).unwrap();
            assert_eq!(decrypted.data.0, data);
        }

        assert_eq!(store.len(host_id, "test").await.unwrap(), 10);
    }

    /// `verify` must fail while an undecryptable record is present. `purge` must delete
    /// exactly that record and keep the readable one.
    #[tokio::test]
    async fn verify_and_purge() {
        let store = SqliteStore::new(":memory:", test_local_timeout())
            .await
            .unwrap();
        let (key, _) = generate_encoded_key().unwrap();
        let (other_key, _) = generate_encoded_key().unwrap();
        let host_id = HostId(uuid_v7());

        let readable = Record::builder()
            .host(Host::new(host_id))
            .version(String::from("test"))
            .tag(String::from("test"))
            .idx(0)
            .data(DecryptedData(vec![0u8, 1u8, 2u8, 3u8]))
            .build();
        let readable = readable.encrypt::<PASETO_V4>(&key.into());

        let unreadable = Record::builder()
            .host(Host::new(host_id))
            .version(String::from("test"))
            .tag(String::from("test"))
            .idx(1)
            .data(DecryptedData(vec![4u8, 5u8, 6u8, 7u8]))
            .build();
        let unreadable = unreadable.encrypt::<PASETO_V4>(&other_key.into());

        store.push(&readable).await.unwrap();
        store.push(&unreadable).await.unwrap();

        assert!(
            store.verify(&key.into()).await.is_err(),
            "verify should fail while a record encrypted with another key is present"
        );

        store
            .purge(&key.into())
            .await
            .expect("failed to purge store");

        store
            .verify(&key.into())
            .await
            .expect("store should verify after purge");
        assert_eq!(
            store.len_tag("test").await.unwrap(),
            1,
            "purge should keep exactly the readable record"
        );
        assert!(
            store.get(readable.id).await.is_ok(),
            "readable record should survive purge"
        );
    }
}