1use std::{cmp::Ordering, fmt::Write};
3
4use eyre::Result;
5use thiserror::Error;
6
7use super::store::Store;
8use crate::{api_client::Client, settings::Settings};
9
10use atuin_common::record::{Diff, HostId, RecordId, RecordIdx, RecordStatus};
11use indicatif::{ProgressBar, ProgressState, ProgressStyle};
12
/// Errors returned by the record sync functions in this module.
///
/// Variants carrying `msg` render it with `{msg:?}` (Debug formatting), so the
/// underlying message appears quoted in the error text.
#[derive(Error, Debug)]
pub enum SyncError {
    /// Per its message: the local store is ahead of the remote, but for a
    /// different host, which may mean the remote lost data.
    /// NOTE(review): not constructed in the code visible here — confirm where it is raised.
    #[error("the local store is ahead of the remote, but for another host. has remote lost data?")]
    LocalAheadOtherHost,

    /// An operation on the local [`Store`] (reading status, paging records,
    /// inserting a batch) failed.
    #[error("an issue with the local database occurred: {msg:?}")]
    LocalStoreError { msg: String },

    /// The sync logic hit an inconsistent state, e.g. a diff with neither a
    /// local nor a remote index.
    #[error("something has gone wrong with the sync logic: {msg:?}")]
    SyncLogicError { msg: String },

    /// A setup step failed, e.g. constructing the API [`Client`].
    #[error("operational error: {msg:?}")]
    OperationalError { msg: String },

    /// A request to the sync server failed. Also used when the session token
    /// needed to authenticate cannot be read from settings.
    #[error("a request to the sync server failed: {msg:?}")]
    RemoteRequestError { msg: String },
}
30
/// One step needed to reconcile a single (host, tag) pair between the local
/// store and the remote.
///
/// Index fields hold the record indices reported by each side's status; in
/// this module's tests they equal the `idx` of the newest record on that side.
#[derive(Debug, Eq, PartialEq)]
pub enum Operation {
    /// Send records that the remote is missing.
    Upload {
        // Local index, which is ahead of the remote.
        local: RecordIdx,
        // `None` when the remote has nothing at all for this (host, tag).
        remote: Option<RecordIdx>,
        host: HostId,
        tag: String,
    },
    /// Fetch records that the local store is missing.
    Download {
        // `None` when the local store has nothing at all for this (host, tag).
        local: Option<RecordIdx>,
        // Remote index, which is ahead of the local store.
        remote: RecordIdx,
        host: HostId,
        tag: String,
    },
    /// Both sides report the same index; nothing to transfer.
    Noop {
        host: HostId,
        tag: String,
    },
}
51
52pub async fn diff(
53 settings: &Settings,
54 store: &impl Store,
55) -> Result<(Vec<Diff>, RecordStatus), SyncError> {
56 let client = Client::new(
57 &settings.sync_address,
58 settings
59 .session_token()
60 .map_err(|e| SyncError::RemoteRequestError { msg: e.to_string() })?
61 .as_str(),
62 settings.network_connect_timeout,
63 settings.network_timeout,
64 )
65 .map_err(|e| SyncError::OperationalError { msg: e.to_string() })?;
66
67 let local_index = store
68 .status()
69 .await
70 .map_err(|e| SyncError::LocalStoreError { msg: e.to_string() })?;
71
72 let remote_index = client
73 .record_status()
74 .await
75 .map_err(|e| SyncError::RemoteRequestError { msg: e.to_string() })?;
76
77 let diff = local_index.diff(&remote_index);
78
79 Ok((diff, remote_index))
80}
81
82pub async fn operations(
87 diffs: Vec<Diff>,
88 _store: &impl Store,
89) -> Result<Vec<Operation>, SyncError> {
90 let mut operations = Vec::with_capacity(diffs.len());
91
92 for diff in diffs {
93 let op = match (diff.local, diff.remote) {
94 (Some(local), Some(remote)) => match local.cmp(&remote) {
96 Ordering::Equal => Operation::Noop {
97 host: diff.host,
98 tag: diff.tag,
99 },
100 Ordering::Greater => Operation::Upload {
101 local,
102 remote: Some(remote),
103 host: diff.host,
104 tag: diff.tag,
105 },
106 Ordering::Less => Operation::Download {
107 local: Some(local),
108 remote,
109 host: diff.host,
110 tag: diff.tag,
111 },
112 },
113
114 (None, Some(remote)) => Operation::Download {
116 local: None,
117 remote,
118 host: diff.host,
119 tag: diff.tag,
120 },
121
122 (Some(local), None) => Operation::Upload {
124 local,
125 remote: None,
126 host: diff.host,
127 tag: diff.tag,
128 },
129
130 (None, None) => {
132 return Err(SyncError::SyncLogicError {
133 msg: String::from(
134 "diff has nothing for local or remote - (host, tag) does not exist",
135 ),
136 })
137 }
138 };
139
140 operations.push(op);
141 }
142
143 operations.sort_by_key(|op| match op {
149 Operation::Noop { host, tag } => (0, *host, tag.clone()),
150
151 Operation::Upload { host, tag, .. } => (1, *host, tag.clone()),
152
153 Operation::Download { host, tag, .. } => (2, *host, tag.clone()),
154 });
155
156 Ok(operations)
157}
158
159async fn sync_upload(
160 store: &impl Store,
161 client: &Client<'_>,
162 host: HostId,
163 tag: String,
164 local: RecordIdx,
165 remote: Option<RecordIdx>,
166) -> Result<i64, SyncError> {
167 let remote = remote.unwrap_or(0);
168 let expected = local - remote;
169 let upload_page_size = 100;
170 let mut progress = 0;
171
172 let pb = ProgressBar::new(expected);
173 pb.set_style(ProgressStyle::with_template("{spinner:.green} [{elapsed_precise}] [{wide_bar:.cyan/blue}] {human_pos}/{human_len} ({eta})")
174 .unwrap()
175 .with_key("eta", |state: &ProgressState, w: &mut dyn Write| write!(w, "{:.1}s", state.eta().as_secs_f64()).unwrap())
176 .progress_chars("#>-"));
177
178 println!(
179 "Uploading {} records to {}/{}",
180 expected,
181 host.0.as_simple(),
182 tag
183 );
184
185 loop {
187 let page = store
188 .next(host, tag.as_str(), remote + progress, upload_page_size)
189 .await
190 .map_err(|e| {
191 error!("failed to read upload page: {e:?}");
192
193 SyncError::LocalStoreError { msg: e.to_string() }
194 })?;
195
196 client.post_records(&page).await.map_err(|e| {
197 error!("failed to post records: {e:?}");
198
199 SyncError::RemoteRequestError { msg: e.to_string() }
200 })?;
201
202 pb.set_position(progress);
203 progress += page.len() as u64;
204
205 if progress >= expected {
206 break;
207 }
208 }
209
210 pb.finish_with_message("Uploaded records");
211
212 Ok(progress as i64)
213}
214
215async fn sync_download(
216 store: &impl Store,
217 client: &Client<'_>,
218 host: HostId,
219 tag: String,
220 local: Option<RecordIdx>,
221 remote: RecordIdx,
222) -> Result<Vec<RecordId>, SyncError> {
223 let local = local.unwrap_or(0);
224 let expected = remote - local;
225 let download_page_size = 100;
226 let mut progress = 0;
227 let mut ret = Vec::new();
228
229 println!(
230 "Downloading {} records from {}/{}",
231 expected,
232 host.0.as_simple(),
233 tag
234 );
235
236 let pb = ProgressBar::new(expected);
237 pb.set_style(ProgressStyle::with_template("{spinner:.green} [{elapsed_precise}] [{wide_bar:.cyan/blue}] {human_pos}/{human_len} ({eta})")
238 .unwrap()
239 .with_key("eta", |state: &ProgressState, w: &mut dyn Write| write!(w, "{:.1}s", state.eta().as_secs_f64()).unwrap())
240 .progress_chars("#>-"));
241
242 loop {
244 let page = client
245 .next_records(host, tag.clone(), local + progress, download_page_size)
246 .await
247 .map_err(|e| SyncError::RemoteRequestError { msg: e.to_string() })?;
248
249 store
250 .push_batch(page.iter())
251 .await
252 .map_err(|e| SyncError::LocalStoreError { msg: e.to_string() })?;
253
254 ret.extend(page.iter().map(|f| f.id));
255
256 pb.set_position(progress);
257 progress += page.len() as u64;
258
259 if progress >= expected {
260 break;
261 }
262 }
263
264 pb.finish_with_message("Downloaded records");
265
266 Ok(ret)
267}
268
269pub async fn sync_remote(
270 operations: Vec<Operation>,
271 local_store: &impl Store,
272 settings: &Settings,
273) -> Result<(i64, Vec<RecordId>), SyncError> {
274 let client = Client::new(
275 &settings.sync_address,
276 settings
277 .session_token()
278 .map_err(|e| SyncError::RemoteRequestError { msg: e.to_string() })?
279 .as_str(),
280 settings.network_connect_timeout,
281 settings.network_timeout,
282 )
283 .expect("failed to create client");
284
285 let mut uploaded = 0;
286 let mut downloaded = Vec::new();
287
288 for i in operations {
290 match i {
291 Operation::Upload {
292 host,
293 tag,
294 local,
295 remote,
296 } => uploaded += sync_upload(local_store, &client, host, tag, local, remote).await?,
297
298 Operation::Download {
299 host,
300 tag,
301 local,
302 remote,
303 } => {
304 let mut d = sync_download(local_store, &client, host, tag, local, remote).await?;
305 downloaded.append(&mut d)
306 }
307
308 Operation::Noop { .. } => continue,
309 }
310 }
311
312 Ok((uploaded, downloaded))
313}
314
315pub async fn sync(
316 settings: &Settings,
317 store: &impl Store,
318) -> Result<(i64, Vec<RecordId>), SyncError> {
319 let (diff, _) = diff(settings, store).await?;
320 let operations = operations(diff, store).await?;
321 let (uploaded, downloaded) = sync_remote(operations, store, settings).await?;
322
323 Ok((uploaded, downloaded))
324}
325
#[cfg(test)]
mod tests {
    use atuin_common::record::{Diff, EncryptedData, HostId, Record};
    use pretty_assertions::assert_eq;

    use crate::{
        record::{
            encryption::PASETO_V4,
            sqlite_store::SqliteStore,
            store::Store,
            sync::{self, Operation},
        },
        settings::test_local_timeout,
    };

    // Build a record at idx 0 with a new host id and a new tag (both from
    // uuid_v7) and an empty encrypted payload. Each call therefore starts a
    // separate (host, tag) chain.
    fn test_record() -> Record<EncryptedData> {
        Record::builder()
            .host(atuin_common::record::Host::new(HostId(
                atuin_common::utils::uuid_v7(),
            )))
            .version("v1".into())
            .tag(atuin_common::utils::uuid_v7().simple().to_string())
            .data(EncryptedData {
                data: String::new(),
                content_encryption_key: String::new(),
            })
            .idx(0)
            .build()
    }

    // Push `local_records` into one in-memory store and `remote_records` into
    // a second in-memory store that stands in for the remote, then diff their
    // statuses. Returns the local store together with the diff.
    async fn build_test_diff(
        local_records: Vec<Record<EncryptedData>>,
        remote_records: Vec<Record<EncryptedData>>,
    ) -> (SqliteStore, Vec<Diff>) {
        let local_store = SqliteStore::new(":memory:", test_local_timeout())
            .await
            .expect("failed to open in memory sqlite");
        let remote_store = SqliteStore::new(":memory:", test_local_timeout())
            .await
            .expect("failed to open in memory sqlite");

        for i in local_records {
            local_store.push(&i).await.unwrap();
        }

        for i in remote_records {
            remote_store.push(&i).await.unwrap();
        }

        let local_index = local_store.status().await.unwrap();
        let remote_index = remote_store.status().await.unwrap();

        let diff = local_index.diff(&remote_index);

        (local_store, diff)
    }

    // A single record that exists only locally must produce one upload with
    // no remote index.
    #[tokio::test]
    async fn test_basic_diff() {
        let record = test_record();

        let (store, diff) = build_test_diff(vec![record.clone()], vec![]).await;

        assert_eq!(diff.len(), 1);

        let operations = sync::operations(diff, &store).await.unwrap();

        assert_eq!(operations.len(), 1);

        assert_eq!(
            operations[0],
            Operation::Upload {
                host: record.host.id,
                tag: record.tag,
                local: record.idx,
                remote: None,
            }
        );
    }

    // One chain where local is ahead of remote, plus one chain that exists
    // only on the remote: expect one upload and one download, uploads first.
    #[tokio::test]
    async fn build_two_way_diff() {
        let shared_record = test_record();
        let remote_ahead = test_record();

        // Next record in `shared_record`'s chain, held only locally.
        let local_ahead = shared_record
            .append(vec![1, 2, 3])
            .encrypt::<PASETO_V4>(&[0; 32]);

        assert_eq!(local_ahead.idx, 1);

        let local = vec![shared_record.clone(), local_ahead.clone()];
        let remote = vec![shared_record.clone(), remote_ahead.clone()];
        let (store, diff) = build_test_diff(local, remote).await;
        let operations = sync::operations(diff, &store).await.unwrap();

        assert_eq!(operations.len(), 2);

        assert_eq!(
            operations,
            vec![
                // local has idx 0-1, remote only idx 0
                Operation::Upload {
                    host: local_ahead.host.id,
                    tag: local_ahead.tag,
                    local: 1,
                    remote: Some(0),
                },
                // remote-only chain at idx 0
                Operation::Download {
                    host: remote_ahead.host.id,
                    tag: remote_ahead.tag,
                    local: None,
                    remote: 0,
                },
            ]
        );
    }

    // Mix of local-only, remote-only, local-ahead and remote-ahead chains.
    // `shared_record` is identical on both sides, so it produces no operation
    // (the diff omits it), which is why seven operations are expected.
    #[tokio::test]
    async fn build_complex_diff() {
        let shared_record = test_record();
        let local_only = test_record();

        // Local-only chain with idx 0..=3.
        let local_only_20 = test_record();
        let local_only_21 = local_only_20
            .append(vec![1, 2, 3])
            .encrypt::<PASETO_V4>(&[0; 32]);
        let local_only_22 = local_only_21
            .append(vec![1, 2, 3])
            .encrypt::<PASETO_V4>(&[0; 32]);
        let local_only_23 = local_only_22
            .append(vec![1, 2, 3])
            .encrypt::<PASETO_V4>(&[0; 32]);

        let remote_only = test_record();

        // Remote-only chain with idx 0..=4.
        let remote_only_20 = test_record();
        let remote_only_21 = remote_only_20
            .append(vec![2, 3, 2])
            .encrypt::<PASETO_V4>(&[0; 32]);
        let remote_only_22 = remote_only_21
            .append(vec![2, 3, 2])
            .encrypt::<PASETO_V4>(&[0; 32]);
        let remote_only_23 = remote_only_22
            .append(vec![2, 3, 2])
            .encrypt::<PASETO_V4>(&[0; 32]);
        let remote_only_24 = remote_only_23
            .append(vec![2, 3, 2])
            .encrypt::<PASETO_V4>(&[0; 32]);

        // Shared chain: local has idx 0, remote has idx 0..=2.
        let second_shared = test_record();
        let second_shared_remote_ahead = second_shared
            .append(vec![1, 2, 3])
            .encrypt::<PASETO_V4>(&[0; 32]);
        let second_shared_remote_ahead2 = second_shared_remote_ahead
            .append(vec![1, 2, 3])
            .encrypt::<PASETO_V4>(&[0; 32]);

        // Shared chain: local has idx 0..=2, remote has idx 0.
        let third_shared = test_record();
        let third_shared_local_ahead = third_shared
            .append(vec![1, 2, 3])
            .encrypt::<PASETO_V4>(&[0; 32]);
        let third_shared_local_ahead2 = third_shared_local_ahead
            .append(vec![1, 2, 3])
            .encrypt::<PASETO_V4>(&[0; 32]);

        // Shared chain: local has idx 0..=1, remote has idx 0..=2.
        let fourth_shared = test_record();
        let fourth_shared_remote_ahead = fourth_shared
            .append(vec![1, 2, 3])
            .encrypt::<PASETO_V4>(&[0; 32]);
        let fourth_shared_remote_ahead2 = fourth_shared_remote_ahead
            .append(vec![1, 2, 3])
            .encrypt::<PASETO_V4>(&[0; 32]);

        let local = vec![
            shared_record.clone(),
            second_shared.clone(),
            third_shared.clone(),
            fourth_shared.clone(),
            fourth_shared_remote_ahead.clone(),
            local_only.clone(),
            local_only_20.clone(),
            local_only_21.clone(),
            local_only_22.clone(),
            local_only_23.clone(),
            third_shared_local_ahead.clone(),
            third_shared_local_ahead2.clone(),
        ];

        let remote = vec![
            remote_only.clone(),
            remote_only_20.clone(),
            remote_only_21.clone(),
            remote_only_22.clone(),
            remote_only_23.clone(),
            remote_only_24.clone(),
            shared_record.clone(),
            second_shared.clone(),
            third_shared.clone(),
            second_shared_remote_ahead.clone(),
            second_shared_remote_ahead2.clone(),
            fourth_shared.clone(),
            fourth_shared_remote_ahead.clone(),
            fourth_shared_remote_ahead2.clone(),
        ];
        let (store, diff) = build_test_diff(local, remote).await;
        let operations = sync::operations(diff, &store).await.unwrap();

        assert_eq!(operations.len(), 7);

        let mut result_ops = vec![
            // second_shared: remote ahead
            Operation::Download {
                local: Some(0),
                remote: 2,
                host: second_shared_remote_ahead.host.id,
                tag: second_shared_remote_ahead.tag,
            },
            // fourth_shared: remote ahead
            Operation::Download {
                local: Some(1),
                remote: 2,
                host: fourth_shared_remote_ahead2.host.id,
                tag: fourth_shared_remote_ahead2.tag,
            },
            // remote_only: single remote record
            Operation::Download {
                local: None,
                remote: 0,
                host: remote_only.host.id,
                tag: remote_only.tag,
            },
            // remote_only_20 chain
            Operation::Download {
                local: None,
                remote: 4,
                host: remote_only_20.host.id,
                tag: remote_only_20.tag,
            },
            // local_only: single local record
            Operation::Upload {
                local: 0,
                remote: None,
                host: local_only.host.id,
                tag: local_only.tag,
            },
            // local_only_20 chain
            Operation::Upload {
                local: 3,
                remote: None,
                host: local_only_20.host.id,
                tag: local_only_20.tag,
            },
            // third_shared: local ahead
            Operation::Upload {
                local: 2,
                remote: Some(0),
                host: third_shared.host.id,
                tag: third_shared.tag,
            },
        ];

        // Order the expectation the same way `sync::operations` orders its
        // output: no-ops, then uploads, then downloads, each by host then tag.
        result_ops.sort_by_key(|op| match op {
            Operation::Noop { host, tag } => (0, *host, tag.clone()),

            Operation::Upload { host, tag, .. } => (1, *host, tag.clone()),

            Operation::Download { host, tag, .. } => (2, *host, tag.clone()),
        });

        assert_eq!(result_ops, operations);
    }
}