1use std::{
2 borrow::Cow,
3 env,
4 path::{Path, PathBuf},
5 str::FromStr,
6 time::Duration,
7};
8
9use async_trait::async_trait;
10use atuin_common::utils;
11use fs_err as fs;
12use itertools::Itertools;
13use rand::{distributions::Alphanumeric, Rng};
14use sql_builder::{bind::Bind, esc, quote, SqlBuilder, SqlName};
15use sqlx::{
16 sqlite::{
17 SqliteConnectOptions, SqliteJournalMode, SqlitePool, SqlitePoolOptions, SqliteRow,
18 SqliteSynchronous,
19 },
20 Result, Row,
21};
22use time::OffsetDateTime;
23
24use crate::{
25 history::{HistoryId, HistoryStats},
26 utils::get_host_user,
27};
28
29use super::{
30 history::History,
31 ordering,
32 settings::{FilterMode, SearchMode, Settings},
33};
34
/// Describes where a history query is being run from; used by the filter modes
/// of [`Database::list`] and [`Database::search`].
#[derive(Debug, Clone)]
pub struct Context {
    /// Shell session identifier (read from `$ATUIN_SESSION` by [`current_context`]).
    pub session: String,
    /// Current working directory; matched exactly by `FilterMode::Directory`.
    pub cwd: String,
    /// Host/user string; matched by `FilterMode::Host`.
    pub hostname: String,
    /// Host ID, rendered via `as_simple()` by [`current_context`].
    pub host_id: String,
    /// Git repository root containing `cwd`, if any. `FilterMode::Workspace`
    /// matches commands run under this directory (falling back to `cwd`).
    pub git_root: Option<PathBuf>,
}
42
/// Optional extra constraints applied by [`Database::search`].
///
/// Every field defaults to "no constraint" (`None` / `false`).
#[derive(Debug, Default, Clone)]
pub struct OptFilters {
    /// Only commands that exited with this status.
    pub exit: Option<i64>,
    /// Exclude commands that exited with this status.
    pub exclude_exit: Option<i64>,
    /// Only commands run in exactly this directory.
    pub cwd: Option<String>,
    /// Exclude commands run in exactly this directory.
    pub exclude_cwd: Option<String>,
    /// Only commands run before this date (parsed with `interim`, UK dialect;
    /// an unparseable value is ignored).
    pub before: Option<String>,
    /// Only commands run after this date (parsed with `interim`, UK dialect;
    /// an unparseable value is ignored).
    pub after: Option<String>,
    /// Maximum number of rows returned.
    pub limit: Option<i64>,
    /// Number of rows to skip.
    pub offset: Option<i64>,
    /// Return oldest results first instead of newest first.
    pub reverse: bool,
}
55
56pub fn current_context() -> Context {
57 let Ok(session) = env::var("ATUIN_SESSION") else {
58 eprintln!("ERROR: Failed to find $ATUIN_SESSION in the environment. Check that you have correctly set up your shell.");
59 std::process::exit(1);
60 };
61 let hostname = get_host_user();
62 let cwd = utils::get_current_dir();
63 let host_id = Settings::host_id().expect("failed to load host ID");
64 let git_root = utils::in_git_repo(cwd.as_str());
65
66 Context {
67 session,
68 hostname,
69 cwd,
70 git_root,
71 host_id: host_id.0.as_simple().to_string(),
72 }
73}
74
75#[async_trait]
76pub trait Database: Send + Sync + 'static {
77 async fn save(&self, h: &History) -> Result<()>;
78 async fn save_bulk(&self, h: &[History]) -> Result<()>;
79
80 async fn load(&self, id: &str) -> Result<Option<History>>;
81 async fn list(
82 &self,
83 filters: &[FilterMode],
84 context: &Context,
85 max: Option<usize>,
86 unique: bool,
87 include_deleted: bool,
88 ) -> Result<Vec<History>>;
89 async fn range(&self, from: OffsetDateTime, to: OffsetDateTime) -> Result<Vec<History>>;
90
91 async fn update(&self, h: &History) -> Result<()>;
92 async fn history_count(&self, include_deleted: bool) -> Result<i64>;
93
94 async fn last(&self) -> Result<Option<History>>;
95 async fn before(&self, timestamp: OffsetDateTime, count: i64) -> Result<Vec<History>>;
96
97 async fn delete(&self, h: History) -> Result<()>;
98 async fn delete_rows(&self, ids: &[HistoryId]) -> Result<()>;
99 async fn deleted(&self) -> Result<Vec<History>>;
100
101 #[allow(clippy::too_many_arguments)]
105 async fn search(
106 &self,
107 search_mode: SearchMode,
108 filter: FilterMode,
109 context: &Context,
110 query: &str,
111 filter_options: OptFilters,
112 ) -> Result<Vec<History>>;
113
114 async fn query_history(&self, query: &str) -> Result<Vec<History>>;
115
116 async fn all_with_count(&self) -> Result<Vec<(History, i32)>>;
117
118 async fn stats(&self, h: &History) -> Result<HistoryStats>;
119}
120
121#[derive(Debug, Clone)]
124pub struct Sqlite {
125 pub pool: SqlitePool,
126}
127
128impl Sqlite {
129 pub async fn new(path: impl AsRef<Path>, timeout: f64) -> Result<Self> {
130 let path = path.as_ref();
131 debug!("opening sqlite database at {:?}", path);
132
133 let create = !path.exists();
134 if create {
135 if let Some(dir) = path.parent() {
136 fs::create_dir_all(dir)?;
137 }
138 }
139
140 let opts = SqliteConnectOptions::from_str(path.as_os_str().to_str().unwrap())?
141 .journal_mode(SqliteJournalMode::Wal)
142 .optimize_on_close(true, None)
143 .synchronous(SqliteSynchronous::Normal)
144 .with_regexp()
145 .create_if_missing(true);
146
147 let pool = SqlitePoolOptions::new()
148 .acquire_timeout(Duration::from_secs_f64(timeout))
149 .connect_with(opts)
150 .await?;
151
152 Self::setup_db(&pool).await?;
153
154 Ok(Self { pool })
155 }
156
157 pub async fn sqlite_version(&self) -> Result<String> {
158 sqlx::query_scalar("SELECT sqlite_version()")
159 .fetch_one(&self.pool)
160 .await
161 }
162
163 async fn setup_db(pool: &SqlitePool) -> Result<()> {
164 debug!("running sqlite database setup");
165
166 sqlx::migrate!("./migrations").run(pool).await?;
167
168 Ok(())
169 }
170
171 async fn save_raw(tx: &mut sqlx::Transaction<'_, sqlx::Sqlite>, h: &History) -> Result<()> {
172 sqlx::query(
173 "insert or ignore into history(id, timestamp, duration, exit, command, cwd, session, hostname, deleted_at)
174 values(?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9)",
175 )
176 .bind(h.id.0.as_str())
177 .bind(h.timestamp.unix_timestamp_nanos() as i64)
178 .bind(h.duration)
179 .bind(h.exit)
180 .bind(h.command.as_str())
181 .bind(h.cwd.as_str())
182 .bind(h.session.as_str())
183 .bind(h.hostname.as_str())
184 .bind(h.deleted_at.map(|t|t.unix_timestamp_nanos() as i64))
185 .execute(&mut **tx)
186 .await?;
187
188 Ok(())
189 }
190
191 async fn delete_row_raw(
192 tx: &mut sqlx::Transaction<'_, sqlx::Sqlite>,
193 id: HistoryId,
194 ) -> Result<()> {
195 sqlx::query("delete from history where id = ?1")
196 .bind(id.0.as_str())
197 .execute(&mut **tx)
198 .await?;
199
200 Ok(())
201 }
202
203 fn query_history(row: SqliteRow) -> History {
204 let deleted_at: Option<i64> = row.get("deleted_at");
205
206 History::from_db()
207 .id(row.get("id"))
208 .timestamp(
209 OffsetDateTime::from_unix_timestamp_nanos(row.get::<i64, _>("timestamp") as i128)
210 .unwrap(),
211 )
212 .duration(row.get("duration"))
213 .exit(row.get("exit"))
214 .command(row.get("command"))
215 .cwd(row.get("cwd"))
216 .session(row.get("session"))
217 .hostname(row.get("hostname"))
218 .deleted_at(
219 deleted_at.and_then(|t| OffsetDateTime::from_unix_timestamp_nanos(t as i128).ok()),
220 )
221 .build()
222 .into()
223 }
224}
225
226#[async_trait]
227impl Database for Sqlite {
228 async fn save(&self, h: &History) -> Result<()> {
229 debug!("saving history to sqlite");
230 let mut tx = self.pool.begin().await?;
231 Self::save_raw(&mut tx, h).await?;
232 tx.commit().await?;
233
234 Ok(())
235 }
236
237 async fn save_bulk(&self, h: &[History]) -> Result<()> {
238 debug!("saving history to sqlite");
239
240 let mut tx = self.pool.begin().await?;
241
242 for i in h {
243 Self::save_raw(&mut tx, i).await?;
244 }
245
246 tx.commit().await?;
247
248 Ok(())
249 }
250
251 async fn load(&self, id: &str) -> Result<Option<History>> {
252 debug!("loading history item {}", id);
253
254 let res = sqlx::query("select * from history where id = ?1")
255 .bind(id)
256 .map(Self::query_history)
257 .fetch_optional(&self.pool)
258 .await?;
259
260 Ok(res)
261 }
262
263 async fn update(&self, h: &History) -> Result<()> {
264 debug!("updating sqlite history");
265
266 sqlx::query(
267 "update history
268 set timestamp = ?2, duration = ?3, exit = ?4, command = ?5, cwd = ?6, session = ?7, hostname = ?8, deleted_at = ?9
269 where id = ?1",
270 )
271 .bind(h.id.0.as_str())
272 .bind(h.timestamp.unix_timestamp_nanos() as i64)
273 .bind(h.duration)
274 .bind(h.exit)
275 .bind(h.command.as_str())
276 .bind(h.cwd.as_str())
277 .bind(h.session.as_str())
278 .bind(h.hostname.as_str())
279 .bind(h.deleted_at.map(|t|t.unix_timestamp_nanos() as i64))
280 .execute(&self.pool)
281 .await?;
282
283 Ok(())
284 }
285
286 async fn list(
288 &self,
289 filters: &[FilterMode],
290 context: &Context,
291 max: Option<usize>,
292 unique: bool,
293 include_deleted: bool,
294 ) -> Result<Vec<History>> {
295 debug!("listing history");
296
297 let mut query = SqlBuilder::select_from(SqlName::new("history").alias("h").baquoted());
298 query.field("*").order_desc("timestamp");
299 if !include_deleted {
300 query.and_where_is_null("deleted_at");
301 }
302
303 let git_root = if let Some(git_root) = context.git_root.clone() {
304 git_root.to_str().unwrap_or("/").to_string()
305 } else {
306 context.cwd.clone()
307 };
308
309 for filter in filters {
310 match filter {
311 FilterMode::Global => &mut query,
312 FilterMode::Host => query.and_where_eq("hostname", quote(&context.hostname)),
313 FilterMode::Session => query.and_where_eq("session", quote(&context.session)),
314 FilterMode::Directory => query.and_where_eq("cwd", quote(&context.cwd)),
315 FilterMode::Workspace => query.and_where_like_left("cwd", &git_root),
316 };
317 }
318
319 if unique {
320 query.group_by("command").having("max(timestamp)");
321 }
322
323 if let Some(max) = max {
324 query.limit(max);
325 }
326
327 let query = query.sql().expect("bug in list query. please report");
328
329 let res = sqlx::query(&query)
330 .map(Self::query_history)
331 .fetch_all(&self.pool)
332 .await?;
333
334 Ok(res)
335 }
336
337 async fn range(&self, from: OffsetDateTime, to: OffsetDateTime) -> Result<Vec<History>> {
338 debug!("listing history from {:?} to {:?}", from, to);
339
340 let res = sqlx::query(
341 "select * from history where timestamp >= ?1 and timestamp <= ?2 order by timestamp asc",
342 )
343 .bind(from.unix_timestamp_nanos() as i64)
344 .bind(to.unix_timestamp_nanos() as i64)
345 .map(Self::query_history)
346 .fetch_all(&self.pool)
347 .await?;
348
349 Ok(res)
350 }
351
352 async fn last(&self) -> Result<Option<History>> {
353 let res = sqlx::query(
354 "select * from history where duration >= 0 order by timestamp desc limit 1",
355 )
356 .map(Self::query_history)
357 .fetch_optional(&self.pool)
358 .await?;
359
360 Ok(res)
361 }
362
363 async fn before(&self, timestamp: OffsetDateTime, count: i64) -> Result<Vec<History>> {
364 let res = sqlx::query(
365 "select * from history where timestamp < ?1 order by timestamp desc limit ?2",
366 )
367 .bind(timestamp.unix_timestamp_nanos() as i64)
368 .bind(count)
369 .map(Self::query_history)
370 .fetch_all(&self.pool)
371 .await?;
372
373 Ok(res)
374 }
375
376 async fn deleted(&self) -> Result<Vec<History>> {
377 let res = sqlx::query("select * from history where deleted_at is not null")
378 .map(Self::query_history)
379 .fetch_all(&self.pool)
380 .await?;
381
382 Ok(res)
383 }
384
385 async fn history_count(&self, include_deleted: bool) -> Result<i64> {
386 let query = if include_deleted {
387 "select count(1) from history"
388 } else {
389 "select count(1) from history where deleted_at is null"
390 };
391
392 let res: (i64,) = sqlx::query_as(query).fetch_one(&self.pool).await?;
393 Ok(res.0)
394 }
395
396 async fn search(
397 &self,
398 search_mode: SearchMode,
399 filter: FilterMode,
400 context: &Context,
401 query: &str,
402 filter_options: OptFilters,
403 ) -> Result<Vec<History>> {
404 let mut sql = SqlBuilder::select_from("history");
405
406 sql.group_by("command").having("max(timestamp)");
407
408 if let Some(limit) = filter_options.limit {
409 sql.limit(limit);
410 }
411
412 if let Some(offset) = filter_options.offset {
413 sql.offset(offset);
414 }
415
416 if filter_options.reverse {
417 sql.order_asc("timestamp");
418 } else {
419 sql.order_desc("timestamp");
420 }
421
422 let git_root = if let Some(git_root) = context.git_root.clone() {
423 git_root.to_str().unwrap_or("/").to_string()
424 } else {
425 context.cwd.clone()
426 };
427
428 match filter {
429 FilterMode::Global => &mut sql,
430 FilterMode::Host => {
431 sql.and_where_eq("lower(hostname)", quote(context.hostname.to_lowercase()))
432 }
433 FilterMode::Session => sql.and_where_eq("session", quote(&context.session)),
434 FilterMode::Directory => sql.and_where_eq("cwd", quote(&context.cwd)),
435 FilterMode::Workspace => sql.and_where_like_left("cwd", git_root),
436 };
437
438 let orig_query = query;
439
440 let mut regexes = Vec::new();
441 match search_mode {
442 SearchMode::Prefix => sql.and_where_like_left("command", query.replace('*', "%")),
443 _ => {
444 let mut is_or = false;
445 let mut regex = None;
446 for part in query.split_inclusive(' ') {
447 let query_part: Cow<str> = match (&mut regex, part.starts_with("r/")) {
448 (None, false) => {
449 if part.trim_end().is_empty() {
450 continue;
451 }
452 Cow::Owned(part.trim_end().replace('*', "%")) }
454 (None, true) => {
455 if part[2..].trim_end().ends_with('/') {
456 let end_pos = part.trim_end().len() - 1;
457 regexes.push(String::from(&part[2..end_pos]));
458 } else {
459 regex = Some(String::from(&part[2..]));
460 }
461 continue;
462 }
463 (Some(r), _) => {
464 if part.trim_end().ends_with('/') {
465 let end_pos = part.trim_end().len() - 1;
466 r.push_str(&part.trim_end()[..end_pos]);
467 regexes.push(regex.take().unwrap());
468 } else {
469 r.push_str(part);
470 }
471 continue;
472 }
473 };
474
475 let (is_glob, glob) = if query_part.contains(char::is_uppercase) {
477 (true, "*")
478 } else {
479 (false, "%")
480 };
481
482 let (is_inverse, query_part) = match query_part.strip_prefix('!') {
483 Some(stripped) => (true, Cow::Borrowed(stripped)),
484 None => (false, query_part),
485 };
486
487 #[allow(clippy::if_same_then_else)]
488 let param = if query_part == "|" {
489 if !is_or {
490 is_or = true;
491 continue;
492 } else {
493 format!("{glob}|{glob}")
494 }
495 } else if let Some(term) = query_part.strip_prefix('^') {
496 format!("{term}{glob}")
497 } else if let Some(term) = query_part.strip_suffix('$') {
498 format!("{glob}{term}")
499 } else if let Some(term) = query_part.strip_prefix('\'') {
500 format!("{glob}{term}{glob}")
501 } else if is_inverse {
502 format!("{glob}{query_part}{glob}")
503 } else if search_mode == SearchMode::FullText {
504 format!("{glob}{query_part}{glob}")
505 } else {
506 query_part.split("").join(glob)
507 };
508
509 sql.fuzzy_condition("command", param, is_inverse, is_glob, is_or);
510 is_or = false;
511 }
512 if let Some(r) = regex {
513 regexes.push(r);
514 }
515
516 &mut sql
517 }
518 };
519
520 for regex in regexes {
521 sql.and_where("command regexp ?".bind(®ex));
522 }
523
524 filter_options
525 .exit
526 .map(|exit| sql.and_where_eq("exit", exit));
527
528 filter_options
529 .exclude_exit
530 .map(|exclude_exit| sql.and_where_ne("exit", exclude_exit));
531
532 filter_options
533 .cwd
534 .map(|cwd| sql.and_where_eq("cwd", quote(cwd)));
535
536 filter_options
537 .exclude_cwd
538 .map(|exclude_cwd| sql.and_where_ne("cwd", quote(exclude_cwd)));
539
540 filter_options.before.map(|before| {
541 interim::parse_date_string(
542 before.as_str(),
543 OffsetDateTime::now_utc(),
544 interim::Dialect::Uk,
545 )
546 .map(|before| {
547 sql.and_where_lt("timestamp", quote(before.unix_timestamp_nanos() as i64))
548 })
549 });
550
551 filter_options.after.map(|after| {
552 interim::parse_date_string(
553 after.as_str(),
554 OffsetDateTime::now_utc(),
555 interim::Dialect::Uk,
556 )
557 .map(|after| sql.and_where_gt("timestamp", quote(after.unix_timestamp_nanos() as i64)))
558 });
559
560 sql.and_where_is_null("deleted_at");
561
562 let query = sql.sql().expect("bug in search query. please report");
563
564 let res = sqlx::query(&query)
565 .map(Self::query_history)
566 .fetch_all(&self.pool)
567 .await?;
568
569 Ok(ordering::reorder_fuzzy(search_mode, orig_query, res))
570 }
571
572 async fn query_history(&self, query: &str) -> Result<Vec<History>> {
573 let res = sqlx::query(query)
574 .map(Self::query_history)
575 .fetch_all(&self.pool)
576 .await?;
577
578 Ok(res)
579 }
580
581 async fn all_with_count(&self) -> Result<Vec<(History, i32)>> {
582 debug!("listing history");
583
584 let mut query = SqlBuilder::select_from(SqlName::new("history").alias("h").baquoted());
585
586 query
587 .fields(&[
588 "id",
589 "max(timestamp) as timestamp",
590 "max(duration) as duration",
591 "exit",
592 "command",
593 "deleted_at",
594 "group_concat(cwd, ':') as cwd",
595 "group_concat(session) as session",
596 "group_concat(hostname, ',') as hostname",
597 "count(*) as count",
598 ])
599 .group_by("command")
600 .group_by("exit")
601 .and_where("deleted_at is null")
602 .order_desc("timestamp");
603
604 let query = query.sql().expect("bug in list query. please report");
605
606 let res = sqlx::query(&query)
607 .map(|row: SqliteRow| {
608 let count: i32 = row.get("count");
609 (Self::query_history(row), count)
610 })
611 .fetch_all(&self.pool)
612 .await?;
613
614 Ok(res)
615 }
616
617 async fn delete(&self, mut h: History) -> Result<()> {
620 let now = OffsetDateTime::now_utc();
621 h.command = rand::thread_rng()
622 .sample_iter(&Alphanumeric)
623 .take(32)
624 .map(char::from)
625 .collect(); h.deleted_at = Some(now); self.update(&h).await?; Ok(())
631 }
632
633 async fn delete_rows(&self, ids: &[HistoryId]) -> Result<()> {
634 let mut tx = self.pool.begin().await?;
635
636 for id in ids {
637 Self::delete_row_raw(&mut tx, id.clone()).await?;
638 }
639
640 tx.commit().await?;
641
642 Ok(())
643 }
644
645 async fn stats(&self, h: &History) -> Result<HistoryStats> {
646 let mut prev = SqlBuilder::select_from("history");
648 prev.field("*")
649 .and_where("timestamp < ?1")
650 .and_where("session = ?2")
651 .order_by("timestamp", true)
652 .limit(1);
653
654 let mut next = SqlBuilder::select_from("history");
655 next.field("*")
656 .and_where("timestamp > ?1")
657 .and_where("session = ?2")
658 .order_by("timestamp", false)
659 .limit(1);
660
661 let mut total = SqlBuilder::select_from("history");
662 total.field("count(1)").and_where("command = ?1");
663
664 let mut average = SqlBuilder::select_from("history");
665 average.field("avg(duration)").and_where("command = ?1");
666
667 let mut exits = SqlBuilder::select_from("history");
668 exits
669 .fields(&["exit", "count(1) as count"])
670 .and_where("command = ?1")
671 .group_by("exit");
672
673 let mut day_of_week = SqlBuilder::select_from("history");
675 day_of_week
676 .fields(&[
677 "strftime('%w', ROUND(timestamp / 1000000000), 'unixepoch') AS day_of_week",
678 "count(1) as count",
679 ])
680 .and_where("command = ?1")
681 .group_by("day_of_week");
682
683 let mut duration_over_time = SqlBuilder::select_from("history");
688 duration_over_time
689 .fields(&[
690 "strftime('01-%m-%Y', ROUND(timestamp / 1000000000), 'unixepoch') AS month_year",
691 "avg(duration) as duration",
692 ])
693 .and_where("command = ?1")
694 .group_by("month_year")
695 .having("duration > 0");
696
697 let prev = prev.sql().expect("issue in stats previous query");
698 let next = next.sql().expect("issue in stats next query");
699 let total = total.sql().expect("issue in stats average query");
700 let average = average.sql().expect("issue in stats previous query");
701 let exits = exits.sql().expect("issue in stats exits query");
702 let day_of_week = day_of_week.sql().expect("issue in stats day of week query");
703 let duration_over_time = duration_over_time
704 .sql()
705 .expect("issue in stats duration over time query");
706
707 let prev = sqlx::query(&prev)
708 .bind(h.timestamp.unix_timestamp_nanos() as i64)
709 .bind(&h.session)
710 .map(Self::query_history)
711 .fetch_optional(&self.pool)
712 .await?;
713
714 let next = sqlx::query(&next)
715 .bind(h.timestamp.unix_timestamp_nanos() as i64)
716 .bind(&h.session)
717 .map(Self::query_history)
718 .fetch_optional(&self.pool)
719 .await?;
720
721 let total: (i64,) = sqlx::query_as(&total)
722 .bind(&h.command)
723 .fetch_one(&self.pool)
724 .await?;
725
726 let average: (f64,) = sqlx::query_as(&average)
727 .bind(&h.command)
728 .fetch_one(&self.pool)
729 .await?;
730
731 let exits: Vec<(i64, i64)> = sqlx::query_as(&exits)
732 .bind(&h.command)
733 .fetch_all(&self.pool)
734 .await?;
735
736 let day_of_week: Vec<(String, i64)> = sqlx::query_as(&day_of_week)
737 .bind(&h.command)
738 .fetch_all(&self.pool)
739 .await?;
740
741 let duration_over_time: Vec<(String, f64)> = sqlx::query_as(&duration_over_time)
742 .bind(&h.command)
743 .fetch_all(&self.pool)
744 .await?;
745
746 let duration_over_time = duration_over_time
747 .iter()
748 .map(|f| (f.0.clone(), f.1.round() as i64))
749 .collect();
750
751 Ok(HistoryStats {
752 next,
753 previous: prev,
754 total: total.0 as u64,
755 average_duration: average.0 as u64,
756 exits,
757 day_of_week,
758 duration_over_time,
759 })
760 }
761}
762
/// Extra condition helpers for [`SqlBuilder`].
trait SqlBuilderExt {
    /// Add a `LIKE` (or, when `glob` is set, `GLOB`) condition comparing `field`
    /// against the escaped pattern `mask`.
    ///
    /// `inverse` negates the match, and `is_or` joins the condition with `OR`
    /// instead of `AND`.
    fn fuzzy_condition<F: ToString, M: ToString>(
        &mut self,
        field: F,
        mask: M,
        inverse: bool,
        glob: bool,
        is_or: bool,
    ) -> &mut Self;
}
773
774impl SqlBuilderExt for SqlBuilder {
775 fn fuzzy_condition<S: ToString, T: ToString>(
777 &mut self,
778 field: S,
779 mask: T,
780 inverse: bool,
781 glob: bool,
782 is_or: bool,
783 ) -> &mut Self {
784 let mut cond = field.to_string();
785 if inverse {
786 cond.push_str(" NOT");
787 }
788 if glob {
789 cond.push_str(" GLOB '");
790 } else {
791 cond.push_str(" LIKE '");
792 }
793 cond.push_str(&esc(mask.to_string()));
794 cond.push('\'');
795 if is_or {
796 self.or_where(cond)
797 } else {
798 self.and_where(cond)
799 }
800 }
801}
802
#[cfg(test)]
mod test {
    use crate::settings::test_local_timeout;

    use super::*;
    use std::time::{Duration, Instant};

    /// Fixed search context shared by the tests below.
    fn test_context() -> Context {
        Context {
            hostname: "test:host".to_string(),
            session: "beepboopiamasession".to_string(),
            cwd: "/home/ellie".to_string(),
            host_id: "test-host".to_string(),
            git_root: None,
        }
    }

    /// Run `query` and assert that it returns exactly `expected` rows.
    ///
    /// On failure the message lists the query and the commands that came back.
    async fn assert_search_eq(
        db: &impl Database,
        mode: SearchMode,
        filter_mode: FilterMode,
        query: &str,
        expected: usize,
    ) -> Result<Vec<History>> {
        let context = test_context();

        let results = db
            .search(mode, filter_mode, &context, query, OptFilters::default())
            .await?;

        assert_eq!(
            results.len(),
            expected,
            "query \"{}\", commands: {:?}",
            query,
            results.iter().map(|a| &a.command).collect::<Vec<&String>>()
        );
        Ok(results)
    }

    /// Like [`assert_search_eq`], but also checks the returned commands and their order.
    async fn assert_search_commands(
        db: &impl Database,
        mode: SearchMode,
        filter_mode: FilterMode,
        query: &str,
        expected_commands: Vec<&str>,
    ) {
        let results = assert_search_eq(db, mode, filter_mode, query, expected_commands.len())
            .await
            .unwrap();
        let commands: Vec<&str> = results.iter().map(|a| a.command.as_str()).collect();
        assert_eq!(commands, expected_commands);
    }

    /// Save `cmd`, timestamped now, with exit 0, duration 1, cwd `/home/ellie`,
    /// session `beep boop` and hostname `booop`.
    async fn new_history_item(db: &impl Database, cmd: &str) -> Result<()> {
        let mut captured: History = History::capture()
            .timestamp(OffsetDateTime::now_utc())
            .command(cmd)
            .cwd("/home/ellie")
            .build()
            .into();

        captured.exit = 0;
        captured.duration = 1;
        captured.session = "beep boop".to_string();
        captured.hostname = "booop".to_string();

        db.save(&captured).await
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn test_search_prefix() {
        let db = Sqlite::new("sqlite::memory:", test_local_timeout())
            .await
            .unwrap();
        new_history_item(&db, "ls /home/ellie").await.unwrap();

        for (query, expected) in [("ls", 1), ("/home", 0), ("ls ", 0)] {
            assert_search_eq(&db, SearchMode::Prefix, FilterMode::Global, query, expected)
                .await
                .unwrap();
        }
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn test_search_fulltext() {
        let db = Sqlite::new("sqlite::memory:", test_local_timeout())
            .await
            .unwrap();
        new_history_item(&db, "ls /home/ellie").await.unwrap();

        let cases = [
            ("ls", 1),
            ("/home", 1),
            ("ls ho", 1),
            ("hm", 0),
            // `r/.../` terms are regular expressions.
            ("r/^ls ", 1),
            ("r/ls / ie$", 1),
            ("r/ls / !ie", 0),
            ("meow r/ls/", 0),
            ("r//hom/", 1),
            ("r//home//", 1),
            ("r//home///", 0),
            // Without `r/`, `.` is matched literally.
            ("/home.*e", 0),
            ("r/home.*e", 1),
        ];
        for (query, expected) in cases {
            assert_search_eq(&db, SearchMode::FullText, FilterMode::Global, query, expected)
                .await
                .unwrap();
        }
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn test_search_fuzzy() {
        let db = Sqlite::new("sqlite::memory:", test_local_timeout())
            .await
            .unwrap();
        new_history_item(&db, "ls /home/ellie").await.unwrap();
        new_history_item(&db, "ls /home/frank").await.unwrap();
        new_history_item(&db, "cd /home/Ellie").await.unwrap();
        new_history_item(&db, "/home/ellie/.bin/rustup")
            .await
            .unwrap();

        let cases = [
            // plain fuzzy matching
            ("ls /", 3),
            ("ls/", 2),
            ("l/h/", 2),
            ("/h/e", 3),
            ("/hmoe/", 0),
            ("ellie/home", 0),
            ("lsellie", 1),
            (" ", 4),
            // anchors, exact terms and negation
            ("^ls", 2),
            ("'ls", 2),
            ("ellie$", 2),
            ("!^ls", 2),
            ("!ellie", 1),
            ("!ellie$", 2),
            // combined terms and OR
            ("ls !ellie", 1),
            ("^ls !e$", 1),
            ("home !^ls", 2),
            ("'frank | 'rustup", 2),
            ("'frank | 'rustup 'ls", 1),
            // an uppercase character makes the match case-sensitive
            ("Ellie", 1),
            // regular expressions
            ("r/^ls ", 2),
            ("r/[Ee]llie", 3),
            ("/h/e r/^ls ", 1),
        ];
        for (query, expected) in cases {
            assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, query, expected)
                .await
                .unwrap();
        }
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn test_search_reordered_fuzzy() {
        let db = Sqlite::new("sqlite::memory:", test_local_timeout())
            .await
            .unwrap();
        new_history_item(&db, "curl").await.unwrap();
        new_history_item(&db, "corburl").await.unwrap();

        assert_search_commands(
            &db,
            SearchMode::Fuzzy,
            FilterMode::Global,
            "curl",
            vec!["curl", "corburl"],
        )
        .await;

        for (query, expected) in [("xxxx", 0), ("", 2)] {
            assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, query, expected)
                .await
                .unwrap();
        }
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn test_search_bench_dupes() {
        let context = test_context();

        let db = Sqlite::new("sqlite::memory:", test_local_timeout())
            .await
            .unwrap();
        for _i in 1..10000 {
            new_history_item(&db, "i am a duplicated command")
                .await
                .unwrap();
        }
        let start = Instant::now();
        let _results = db
            .search(
                SearchMode::Fuzzy,
                FilterMode::Global,
                &context,
                "",
                OptFilters::default(),
            )
            .await
            .unwrap();
        let duration = start.elapsed();

        assert!(duration < Duration::from_secs(15));
    }
}