// atuin_client/database.rs

1use std::{
2    borrow::Cow,
3    env,
4    path::{Path, PathBuf},
5    str::FromStr,
6    time::Duration,
7};
8
9use async_trait::async_trait;
10use atuin_common::utils;
11use fs_err as fs;
12use itertools::Itertools;
13use rand::{distributions::Alphanumeric, Rng};
14use sql_builder::{bind::Bind, esc, quote, SqlBuilder, SqlName};
15use sqlx::{
16    sqlite::{
17        SqliteConnectOptions, SqliteJournalMode, SqlitePool, SqlitePoolOptions, SqliteRow,
18        SqliteSynchronous,
19    },
20    Result, Row,
21};
22use time::OffsetDateTime;
23
24use crate::{
25    history::{HistoryId, HistoryStats},
26    utils::get_host_user,
27};
28
29use super::{
30    history::History,
31    ordering,
32    settings::{FilterMode, SearchMode, Settings},
33};
34
/// Snapshot of the shell environment that history queries are evaluated against.
///
/// Built by [`current_context`]; the `FilterMode`s used by `list` and `search` compare
/// history rows against these fields.
#[derive(Debug, Clone)]
pub struct Context {
    /// Shell session id, read from `$ATUIN_SESSION` by `current_context`.
    pub session: String,
    /// Current working directory (used by the directory filter).
    pub cwd: String,
    /// Value compared against the `hostname` column (from `get_host_user`).
    pub hostname: String,
    /// This machine's host id, rendered as a string by `current_context`.
    pub host_id: String,
    /// Root of the git repository containing `cwd`, if any; the workspace filter matches
    /// directories under this prefix and falls back to `cwd` when it is `None`.
    pub git_root: Option<PathBuf>,
}
42
/// Optional extra constraints applied by `Database::search`.
///
/// Every field left at its default (`None` / `false`) adds no constraint.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct OptFilters {
    /// Keep only entries whose exit code equals this value.
    pub exit: Option<i64>,
    /// Drop entries whose exit code equals this value.
    pub exclude_exit: Option<i64>,
    /// Keep only entries run in exactly this directory.
    pub cwd: Option<String>,
    /// Drop entries run in exactly this directory.
    pub exclude_cwd: Option<String>,
    /// Keep only entries strictly older than this human-readable date (parsed with `interim`).
    pub before: Option<String>,
    /// Keep only entries strictly newer than this human-readable date (parsed with `interim`).
    pub after: Option<String>,
    /// Maximum number of results to return.
    pub limit: Option<i64>,
    /// Number of results to skip before returning.
    pub offset: Option<i64>,
    /// Return the oldest entries first instead of the newest.
    pub reverse: bool,
}
55
56pub fn current_context() -> Context {
57    let Ok(session) = env::var("ATUIN_SESSION") else {
58        eprintln!("ERROR: Failed to find $ATUIN_SESSION in the environment. Check that you have correctly set up your shell.");
59        std::process::exit(1);
60    };
61    let hostname = get_host_user();
62    let cwd = utils::get_current_dir();
63    let host_id = Settings::host_id().expect("failed to load host ID");
64    let git_root = utils::in_git_repo(cwd.as_str());
65
66    Context {
67        session,
68        hostname,
69        cwd,
70        git_root,
71        host_id: host_id.0.as_simple().to_string(),
72    }
73}
74
75#[async_trait]
76pub trait Database: Send + Sync + 'static {
77    async fn save(&self, h: &History) -> Result<()>;
78    async fn save_bulk(&self, h: &[History]) -> Result<()>;
79
80    async fn load(&self, id: &str) -> Result<Option<History>>;
81    async fn list(
82        &self,
83        filters: &[FilterMode],
84        context: &Context,
85        max: Option<usize>,
86        unique: bool,
87        include_deleted: bool,
88    ) -> Result<Vec<History>>;
89    async fn range(&self, from: OffsetDateTime, to: OffsetDateTime) -> Result<Vec<History>>;
90
91    async fn update(&self, h: &History) -> Result<()>;
92    async fn history_count(&self, include_deleted: bool) -> Result<i64>;
93
94    async fn last(&self) -> Result<Option<History>>;
95    async fn before(&self, timestamp: OffsetDateTime, count: i64) -> Result<Vec<History>>;
96
97    async fn delete(&self, h: History) -> Result<()>;
98    async fn delete_rows(&self, ids: &[HistoryId]) -> Result<()>;
99    async fn deleted(&self) -> Result<Vec<History>>;
100
101    // Yes I know, it's a lot.
102    // Could maybe break it down to a searchparams struct or smth but that feels a little... pointless.
103    // Been debating maybe a DSL for search? eg "before:time limit:1 the query"
104    #[allow(clippy::too_many_arguments)]
105    async fn search(
106        &self,
107        search_mode: SearchMode,
108        filter: FilterMode,
109        context: &Context,
110        query: &str,
111        filter_options: OptFilters,
112    ) -> Result<Vec<History>>;
113
114    async fn query_history(&self, query: &str) -> Result<Vec<History>>;
115
116    async fn all_with_count(&self) -> Result<Vec<(History, i32)>>;
117
118    async fn stats(&self, h: &History) -> Result<HistoryStats>;
119}
120
121// Intended for use on a developer machine and not a sync server.
122// TODO: implement IntoIterator
123#[derive(Debug, Clone)]
124pub struct Sqlite {
125    pub pool: SqlitePool,
126}
127
128impl Sqlite {
129    pub async fn new(path: impl AsRef<Path>, timeout: f64) -> Result<Self> {
130        let path = path.as_ref();
131        debug!("opening sqlite database at {:?}", path);
132
133        let create = !path.exists();
134        if create {
135            if let Some(dir) = path.parent() {
136                fs::create_dir_all(dir)?;
137            }
138        }
139
140        let opts = SqliteConnectOptions::from_str(path.as_os_str().to_str().unwrap())?
141            .journal_mode(SqliteJournalMode::Wal)
142            .optimize_on_close(true, None)
143            .synchronous(SqliteSynchronous::Normal)
144            .with_regexp()
145            .create_if_missing(true);
146
147        let pool = SqlitePoolOptions::new()
148            .acquire_timeout(Duration::from_secs_f64(timeout))
149            .connect_with(opts)
150            .await?;
151
152        Self::setup_db(&pool).await?;
153
154        Ok(Self { pool })
155    }
156
157    pub async fn sqlite_version(&self) -> Result<String> {
158        sqlx::query_scalar("SELECT sqlite_version()")
159            .fetch_one(&self.pool)
160            .await
161    }
162
163    async fn setup_db(pool: &SqlitePool) -> Result<()> {
164        debug!("running sqlite database setup");
165
166        sqlx::migrate!("./migrations").run(pool).await?;
167
168        Ok(())
169    }
170
171    async fn save_raw(tx: &mut sqlx::Transaction<'_, sqlx::Sqlite>, h: &History) -> Result<()> {
172        sqlx::query(
173            "insert or ignore into history(id, timestamp, duration, exit, command, cwd, session, hostname, deleted_at)
174                values(?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9)",
175        )
176        .bind(h.id.0.as_str())
177        .bind(h.timestamp.unix_timestamp_nanos() as i64)
178        .bind(h.duration)
179        .bind(h.exit)
180        .bind(h.command.as_str())
181        .bind(h.cwd.as_str())
182        .bind(h.session.as_str())
183        .bind(h.hostname.as_str())
184        .bind(h.deleted_at.map(|t|t.unix_timestamp_nanos() as i64))
185        .execute(&mut **tx)
186        .await?;
187
188        Ok(())
189    }
190
191    async fn delete_row_raw(
192        tx: &mut sqlx::Transaction<'_, sqlx::Sqlite>,
193        id: HistoryId,
194    ) -> Result<()> {
195        sqlx::query("delete from history where id = ?1")
196            .bind(id.0.as_str())
197            .execute(&mut **tx)
198            .await?;
199
200        Ok(())
201    }
202
203    fn query_history(row: SqliteRow) -> History {
204        let deleted_at: Option<i64> = row.get("deleted_at");
205
206        History::from_db()
207            .id(row.get("id"))
208            .timestamp(
209                OffsetDateTime::from_unix_timestamp_nanos(row.get::<i64, _>("timestamp") as i128)
210                    .unwrap(),
211            )
212            .duration(row.get("duration"))
213            .exit(row.get("exit"))
214            .command(row.get("command"))
215            .cwd(row.get("cwd"))
216            .session(row.get("session"))
217            .hostname(row.get("hostname"))
218            .deleted_at(
219                deleted_at.and_then(|t| OffsetDateTime::from_unix_timestamp_nanos(t as i128).ok()),
220            )
221            .build()
222            .into()
223    }
224}
225
226#[async_trait]
227impl Database for Sqlite {
228    async fn save(&self, h: &History) -> Result<()> {
229        debug!("saving history to sqlite");
230        let mut tx = self.pool.begin().await?;
231        Self::save_raw(&mut tx, h).await?;
232        tx.commit().await?;
233
234        Ok(())
235    }
236
237    async fn save_bulk(&self, h: &[History]) -> Result<()> {
238        debug!("saving history to sqlite");
239
240        let mut tx = self.pool.begin().await?;
241
242        for i in h {
243            Self::save_raw(&mut tx, i).await?;
244        }
245
246        tx.commit().await?;
247
248        Ok(())
249    }
250
251    async fn load(&self, id: &str) -> Result<Option<History>> {
252        debug!("loading history item {}", id);
253
254        let res = sqlx::query("select * from history where id = ?1")
255            .bind(id)
256            .map(Self::query_history)
257            .fetch_optional(&self.pool)
258            .await?;
259
260        Ok(res)
261    }
262
263    async fn update(&self, h: &History) -> Result<()> {
264        debug!("updating sqlite history");
265
266        sqlx::query(
267            "update history
268                set timestamp = ?2, duration = ?3, exit = ?4, command = ?5, cwd = ?6, session = ?7, hostname = ?8, deleted_at = ?9
269                where id = ?1",
270        )
271        .bind(h.id.0.as_str())
272        .bind(h.timestamp.unix_timestamp_nanos() as i64)
273        .bind(h.duration)
274        .bind(h.exit)
275        .bind(h.command.as_str())
276        .bind(h.cwd.as_str())
277        .bind(h.session.as_str())
278        .bind(h.hostname.as_str())
279        .bind(h.deleted_at.map(|t|t.unix_timestamp_nanos() as i64))
280        .execute(&self.pool)
281        .await?;
282
283        Ok(())
284    }
285
286    // make a unique list, that only shows the *newest* version of things
287    async fn list(
288        &self,
289        filters: &[FilterMode],
290        context: &Context,
291        max: Option<usize>,
292        unique: bool,
293        include_deleted: bool,
294    ) -> Result<Vec<History>> {
295        debug!("listing history");
296
297        let mut query = SqlBuilder::select_from(SqlName::new("history").alias("h").baquoted());
298        query.field("*").order_desc("timestamp");
299        if !include_deleted {
300            query.and_where_is_null("deleted_at");
301        }
302
303        let git_root = if let Some(git_root) = context.git_root.clone() {
304            git_root.to_str().unwrap_or("/").to_string()
305        } else {
306            context.cwd.clone()
307        };
308
309        for filter in filters {
310            match filter {
311                FilterMode::Global => &mut query,
312                FilterMode::Host => query.and_where_eq("hostname", quote(&context.hostname)),
313                FilterMode::Session => query.and_where_eq("session", quote(&context.session)),
314                FilterMode::Directory => query.and_where_eq("cwd", quote(&context.cwd)),
315                FilterMode::Workspace => query.and_where_like_left("cwd", &git_root),
316            };
317        }
318
319        if unique {
320            query.group_by("command").having("max(timestamp)");
321        }
322
323        if let Some(max) = max {
324            query.limit(max);
325        }
326
327        let query = query.sql().expect("bug in list query. please report");
328
329        let res = sqlx::query(&query)
330            .map(Self::query_history)
331            .fetch_all(&self.pool)
332            .await?;
333
334        Ok(res)
335    }
336
337    async fn range(&self, from: OffsetDateTime, to: OffsetDateTime) -> Result<Vec<History>> {
338        debug!("listing history from {:?} to {:?}", from, to);
339
340        let res = sqlx::query(
341            "select * from history where timestamp >= ?1 and timestamp <= ?2 order by timestamp asc",
342        )
343        .bind(from.unix_timestamp_nanos() as i64)
344        .bind(to.unix_timestamp_nanos() as i64)
345            .map(Self::query_history)
346        .fetch_all(&self.pool)
347        .await?;
348
349        Ok(res)
350    }
351
352    async fn last(&self) -> Result<Option<History>> {
353        let res = sqlx::query(
354            "select * from history where duration >= 0 order by timestamp desc limit 1",
355        )
356        .map(Self::query_history)
357        .fetch_optional(&self.pool)
358        .await?;
359
360        Ok(res)
361    }
362
363    async fn before(&self, timestamp: OffsetDateTime, count: i64) -> Result<Vec<History>> {
364        let res = sqlx::query(
365            "select * from history where timestamp < ?1 order by timestamp desc limit ?2",
366        )
367        .bind(timestamp.unix_timestamp_nanos() as i64)
368        .bind(count)
369        .map(Self::query_history)
370        .fetch_all(&self.pool)
371        .await?;
372
373        Ok(res)
374    }
375
376    async fn deleted(&self) -> Result<Vec<History>> {
377        let res = sqlx::query("select * from history where deleted_at is not null")
378            .map(Self::query_history)
379            .fetch_all(&self.pool)
380            .await?;
381
382        Ok(res)
383    }
384
385    async fn history_count(&self, include_deleted: bool) -> Result<i64> {
386        let query = if include_deleted {
387            "select count(1) from history"
388        } else {
389            "select count(1) from history where deleted_at is null"
390        };
391
392        let res: (i64,) = sqlx::query_as(query).fetch_one(&self.pool).await?;
393        Ok(res.0)
394    }
395
396    async fn search(
397        &self,
398        search_mode: SearchMode,
399        filter: FilterMode,
400        context: &Context,
401        query: &str,
402        filter_options: OptFilters,
403    ) -> Result<Vec<History>> {
404        let mut sql = SqlBuilder::select_from("history");
405
406        sql.group_by("command").having("max(timestamp)");
407
408        if let Some(limit) = filter_options.limit {
409            sql.limit(limit);
410        }
411
412        if let Some(offset) = filter_options.offset {
413            sql.offset(offset);
414        }
415
416        if filter_options.reverse {
417            sql.order_asc("timestamp");
418        } else {
419            sql.order_desc("timestamp");
420        }
421
422        let git_root = if let Some(git_root) = context.git_root.clone() {
423            git_root.to_str().unwrap_or("/").to_string()
424        } else {
425            context.cwd.clone()
426        };
427
428        match filter {
429            FilterMode::Global => &mut sql,
430            FilterMode::Host => {
431                sql.and_where_eq("lower(hostname)", quote(context.hostname.to_lowercase()))
432            }
433            FilterMode::Session => sql.and_where_eq("session", quote(&context.session)),
434            FilterMode::Directory => sql.and_where_eq("cwd", quote(&context.cwd)),
435            FilterMode::Workspace => sql.and_where_like_left("cwd", git_root),
436        };
437
438        let orig_query = query;
439
440        let mut regexes = Vec::new();
441        match search_mode {
442            SearchMode::Prefix => sql.and_where_like_left("command", query.replace('*', "%")),
443            _ => {
444                let mut is_or = false;
445                let mut regex = None;
446                for part in query.split_inclusive(' ') {
447                    let query_part: Cow<str> = match (&mut regex, part.starts_with("r/")) {
448                        (None, false) => {
449                            if part.trim_end().is_empty() {
450                                continue;
451                            }
452                            Cow::Owned(part.trim_end().replace('*', "%")) // allow wildcard char
453                        }
454                        (None, true) => {
455                            if part[2..].trim_end().ends_with('/') {
456                                let end_pos = part.trim_end().len() - 1;
457                                regexes.push(String::from(&part[2..end_pos]));
458                            } else {
459                                regex = Some(String::from(&part[2..]));
460                            }
461                            continue;
462                        }
463                        (Some(r), _) => {
464                            if part.trim_end().ends_with('/') {
465                                let end_pos = part.trim_end().len() - 1;
466                                r.push_str(&part.trim_end()[..end_pos]);
467                                regexes.push(regex.take().unwrap());
468                            } else {
469                                r.push_str(part);
470                            }
471                            continue;
472                        }
473                    };
474
475                    // TODO smart case mode could be made configurable like in fzf
476                    let (is_glob, glob) = if query_part.contains(char::is_uppercase) {
477                        (true, "*")
478                    } else {
479                        (false, "%")
480                    };
481
482                    let (is_inverse, query_part) = match query_part.strip_prefix('!') {
483                        Some(stripped) => (true, Cow::Borrowed(stripped)),
484                        None => (false, query_part),
485                    };
486
487                    #[allow(clippy::if_same_then_else)]
488                    let param = if query_part == "|" {
489                        if !is_or {
490                            is_or = true;
491                            continue;
492                        } else {
493                            format!("{glob}|{glob}")
494                        }
495                    } else if let Some(term) = query_part.strip_prefix('^') {
496                        format!("{term}{glob}")
497                    } else if let Some(term) = query_part.strip_suffix('$') {
498                        format!("{glob}{term}")
499                    } else if let Some(term) = query_part.strip_prefix('\'') {
500                        format!("{glob}{term}{glob}")
501                    } else if is_inverse {
502                        format!("{glob}{query_part}{glob}")
503                    } else if search_mode == SearchMode::FullText {
504                        format!("{glob}{query_part}{glob}")
505                    } else {
506                        query_part.split("").join(glob)
507                    };
508
509                    sql.fuzzy_condition("command", param, is_inverse, is_glob, is_or);
510                    is_or = false;
511                }
512                if let Some(r) = regex {
513                    regexes.push(r);
514                }
515
516                &mut sql
517            }
518        };
519
520        for regex in regexes {
521            sql.and_where("command regexp ?".bind(&regex));
522        }
523
524        filter_options
525            .exit
526            .map(|exit| sql.and_where_eq("exit", exit));
527
528        filter_options
529            .exclude_exit
530            .map(|exclude_exit| sql.and_where_ne("exit", exclude_exit));
531
532        filter_options
533            .cwd
534            .map(|cwd| sql.and_where_eq("cwd", quote(cwd)));
535
536        filter_options
537            .exclude_cwd
538            .map(|exclude_cwd| sql.and_where_ne("cwd", quote(exclude_cwd)));
539
540        filter_options.before.map(|before| {
541            interim::parse_date_string(
542                before.as_str(),
543                OffsetDateTime::now_utc(),
544                interim::Dialect::Uk,
545            )
546            .map(|before| {
547                sql.and_where_lt("timestamp", quote(before.unix_timestamp_nanos() as i64))
548            })
549        });
550
551        filter_options.after.map(|after| {
552            interim::parse_date_string(
553                after.as_str(),
554                OffsetDateTime::now_utc(),
555                interim::Dialect::Uk,
556            )
557            .map(|after| sql.and_where_gt("timestamp", quote(after.unix_timestamp_nanos() as i64)))
558        });
559
560        sql.and_where_is_null("deleted_at");
561
562        let query = sql.sql().expect("bug in search query. please report");
563
564        let res = sqlx::query(&query)
565            .map(Self::query_history)
566            .fetch_all(&self.pool)
567            .await?;
568
569        Ok(ordering::reorder_fuzzy(search_mode, orig_query, res))
570    }
571
572    async fn query_history(&self, query: &str) -> Result<Vec<History>> {
573        let res = sqlx::query(query)
574            .map(Self::query_history)
575            .fetch_all(&self.pool)
576            .await?;
577
578        Ok(res)
579    }
580
581    async fn all_with_count(&self) -> Result<Vec<(History, i32)>> {
582        debug!("listing history");
583
584        let mut query = SqlBuilder::select_from(SqlName::new("history").alias("h").baquoted());
585
586        query
587            .fields(&[
588                "id",
589                "max(timestamp) as timestamp",
590                "max(duration) as duration",
591                "exit",
592                "command",
593                "deleted_at",
594                "group_concat(cwd, ':') as cwd",
595                "group_concat(session) as session",
596                "group_concat(hostname, ',') as hostname",
597                "count(*) as count",
598            ])
599            .group_by("command")
600            .group_by("exit")
601            .and_where("deleted_at is null")
602            .order_desc("timestamp");
603
604        let query = query.sql().expect("bug in list query. please report");
605
606        let res = sqlx::query(&query)
607            .map(|row: SqliteRow| {
608                let count: i32 = row.get("count");
609                (Self::query_history(row), count)
610            })
611            .fetch_all(&self.pool)
612            .await?;
613
614        Ok(res)
615    }
616
617    // deleted_at doesn't mean the actual time that the user deleted it,
618    // but the time that the system marks it as deleted
619    async fn delete(&self, mut h: History) -> Result<()> {
620        let now = OffsetDateTime::now_utc();
621        h.command = rand::thread_rng()
622            .sample_iter(&Alphanumeric)
623            .take(32)
624            .map(char::from)
625            .collect(); // overwrite with random string
626        h.deleted_at = Some(now); // delete it
627
628        self.update(&h).await?; // save it
629
630        Ok(())
631    }
632
633    async fn delete_rows(&self, ids: &[HistoryId]) -> Result<()> {
634        let mut tx = self.pool.begin().await?;
635
636        for id in ids {
637            Self::delete_row_raw(&mut tx, id.clone()).await?;
638        }
639
640        tx.commit().await?;
641
642        Ok(())
643    }
644
645    async fn stats(&self, h: &History) -> Result<HistoryStats> {
646        // We select the previous in the session by time
647        let mut prev = SqlBuilder::select_from("history");
648        prev.field("*")
649            .and_where("timestamp < ?1")
650            .and_where("session = ?2")
651            .order_by("timestamp", true)
652            .limit(1);
653
654        let mut next = SqlBuilder::select_from("history");
655        next.field("*")
656            .and_where("timestamp > ?1")
657            .and_where("session = ?2")
658            .order_by("timestamp", false)
659            .limit(1);
660
661        let mut total = SqlBuilder::select_from("history");
662        total.field("count(1)").and_where("command = ?1");
663
664        let mut average = SqlBuilder::select_from("history");
665        average.field("avg(duration)").and_where("command = ?1");
666
667        let mut exits = SqlBuilder::select_from("history");
668        exits
669            .fields(&["exit", "count(1) as count"])
670            .and_where("command = ?1")
671            .group_by("exit");
672
673        // rewrite the following with sqlbuilder
674        let mut day_of_week = SqlBuilder::select_from("history");
675        day_of_week
676            .fields(&[
677                "strftime('%w', ROUND(timestamp / 1000000000), 'unixepoch') AS day_of_week",
678                "count(1) as count",
679            ])
680            .and_where("command = ?1")
681            .group_by("day_of_week");
682
683        // Intentionally format the string with 01 hardcoded. We want the average runtime for the
684        // _entire month_, but will later parse it as a datetime for sorting
685        // Sqlite has no datetime so we cannot do it there, and otherwise sorting will just be a
686        // string sort, which won't be correct.
687        let mut duration_over_time = SqlBuilder::select_from("history");
688        duration_over_time
689            .fields(&[
690                "strftime('01-%m-%Y', ROUND(timestamp / 1000000000), 'unixepoch') AS month_year",
691                "avg(duration) as duration",
692            ])
693            .and_where("command = ?1")
694            .group_by("month_year")
695            .having("duration > 0");
696
697        let prev = prev.sql().expect("issue in stats previous query");
698        let next = next.sql().expect("issue in stats next query");
699        let total = total.sql().expect("issue in stats average query");
700        let average = average.sql().expect("issue in stats previous query");
701        let exits = exits.sql().expect("issue in stats exits query");
702        let day_of_week = day_of_week.sql().expect("issue in stats day of week query");
703        let duration_over_time = duration_over_time
704            .sql()
705            .expect("issue in stats duration over time query");
706
707        let prev = sqlx::query(&prev)
708            .bind(h.timestamp.unix_timestamp_nanos() as i64)
709            .bind(&h.session)
710            .map(Self::query_history)
711            .fetch_optional(&self.pool)
712            .await?;
713
714        let next = sqlx::query(&next)
715            .bind(h.timestamp.unix_timestamp_nanos() as i64)
716            .bind(&h.session)
717            .map(Self::query_history)
718            .fetch_optional(&self.pool)
719            .await?;
720
721        let total: (i64,) = sqlx::query_as(&total)
722            .bind(&h.command)
723            .fetch_one(&self.pool)
724            .await?;
725
726        let average: (f64,) = sqlx::query_as(&average)
727            .bind(&h.command)
728            .fetch_one(&self.pool)
729            .await?;
730
731        let exits: Vec<(i64, i64)> = sqlx::query_as(&exits)
732            .bind(&h.command)
733            .fetch_all(&self.pool)
734            .await?;
735
736        let day_of_week: Vec<(String, i64)> = sqlx::query_as(&day_of_week)
737            .bind(&h.command)
738            .fetch_all(&self.pool)
739            .await?;
740
741        let duration_over_time: Vec<(String, f64)> = sqlx::query_as(&duration_over_time)
742            .bind(&h.command)
743            .fetch_all(&self.pool)
744            .await?;
745
746        let duration_over_time = duration_over_time
747            .iter()
748            .map(|f| (f.0.clone(), f.1.round() as i64))
749            .collect();
750
751        Ok(HistoryStats {
752            next,
753            previous: prev,
754            total: total.0 as u64,
755            average_duration: average.0 as u64,
756            exits,
757            day_of_week,
758            duration_over_time,
759        })
760    }
761}
762
/// Extra condition builders for `SqlBuilder` used by history search.
trait SqlBuilderExt {
    /// Adds a `field [NOT] LIKE|GLOB 'mask'` condition.
    ///
    /// `inverse` adds `NOT`, `glob` selects GLOB instead of LIKE, and `is_or` joins the
    /// condition to the previous one with OR instead of AND.
    fn fuzzy_condition<S, T>(
        &mut self,
        field: S,
        mask: T,
        inverse: bool,
        glob: bool,
        is_or: bool,
    ) -> &mut Self
    where
        S: ToString,
        T: ToString;
}
773
774impl SqlBuilderExt for SqlBuilder {
775    /// adapted from the sql-builder *like functions
776    fn fuzzy_condition<S: ToString, T: ToString>(
777        &mut self,
778        field: S,
779        mask: T,
780        inverse: bool,
781        glob: bool,
782        is_or: bool,
783    ) -> &mut Self {
784        let mut cond = field.to_string();
785        if inverse {
786            cond.push_str(" NOT");
787        }
788        if glob {
789            cond.push_str(" GLOB '");
790        } else {
791            cond.push_str(" LIKE '");
792        }
793        cond.push_str(&esc(mask.to_string()));
794        cond.push('\'');
795        if is_or {
796            self.or_where(cond)
797        } else {
798            self.and_where(cond)
799        }
800    }
801}
802
803#[cfg(test)]
804mod test {
805    use crate::settings::test_local_timeout;
806
807    use super::*;
808    use std::time::{Duration, Instant};
809
810    async fn assert_search_eq<'a>(
811        db: &impl Database,
812        mode: SearchMode,
813        filter_mode: FilterMode,
814        query: &str,
815        expected: usize,
816    ) -> Result<Vec<History>> {
817        let context = Context {
818            hostname: "test:host".to_string(),
819            session: "beepboopiamasession".to_string(),
820            cwd: "/home/ellie".to_string(),
821            host_id: "test-host".to_string(),
822            git_root: None,
823        };
824
825        let results = db
826            .search(
827                mode,
828                filter_mode,
829                &context,
830                query,
831                OptFilters {
832                    ..Default::default()
833                },
834            )
835            .await?;
836
837        assert_eq!(
838            results.len(),
839            expected,
840            "query \"{}\", commands: {:?}",
841            query,
842            results.iter().map(|a| &a.command).collect::<Vec<&String>>()
843        );
844        Ok(results)
845    }
846
847    async fn assert_search_commands(
848        db: &impl Database,
849        mode: SearchMode,
850        filter_mode: FilterMode,
851        query: &str,
852        expected_commands: Vec<&str>,
853    ) {
854        let results = assert_search_eq(db, mode, filter_mode, query, expected_commands.len())
855            .await
856            .unwrap();
857        let commands: Vec<&str> = results.iter().map(|a| a.command.as_str()).collect();
858        assert_eq!(commands, expected_commands);
859    }
860
861    async fn new_history_item(db: &mut impl Database, cmd: &str) -> Result<()> {
862        let mut captured: History = History::capture()
863            .timestamp(OffsetDateTime::now_utc())
864            .command(cmd)
865            .cwd("/home/ellie")
866            .build()
867            .into();
868
869        captured.exit = 0;
870        captured.duration = 1;
871        captured.session = "beep boop".to_string();
872        captured.hostname = "booop".to_string();
873
874        db.save(&captured).await
875    }
876
877    #[tokio::test(flavor = "multi_thread")]
878    async fn test_search_prefix() {
879        let mut db = Sqlite::new("sqlite::memory:", test_local_timeout())
880            .await
881            .unwrap();
882        new_history_item(&mut db, "ls /home/ellie").await.unwrap();
883
884        assert_search_eq(&db, SearchMode::Prefix, FilterMode::Global, "ls", 1)
885            .await
886            .unwrap();
887        assert_search_eq(&db, SearchMode::Prefix, FilterMode::Global, "/home", 0)
888            .await
889            .unwrap();
890        assert_search_eq(&db, SearchMode::Prefix, FilterMode::Global, "ls  ", 0)
891            .await
892            .unwrap();
893    }
894
895    #[tokio::test(flavor = "multi_thread")]
896    async fn test_search_fulltext() {
897        let mut db = Sqlite::new("sqlite::memory:", test_local_timeout())
898            .await
899            .unwrap();
900        new_history_item(&mut db, "ls /home/ellie").await.unwrap();
901
902        assert_search_eq(&db, SearchMode::FullText, FilterMode::Global, "ls", 1)
903            .await
904            .unwrap();
905        assert_search_eq(&db, SearchMode::FullText, FilterMode::Global, "/home", 1)
906            .await
907            .unwrap();
908        assert_search_eq(&db, SearchMode::FullText, FilterMode::Global, "ls ho", 1)
909            .await
910            .unwrap();
911        assert_search_eq(&db, SearchMode::FullText, FilterMode::Global, "hm", 0)
912            .await
913            .unwrap();
914
915        // regex
916        assert_search_eq(&db, SearchMode::FullText, FilterMode::Global, "r/^ls ", 1)
917            .await
918            .unwrap();
919        assert_search_eq(
920            &db,
921            SearchMode::FullText,
922            FilterMode::Global,
923            "r/ls / ie$",
924            1,
925        )
926        .await
927        .unwrap();
928        assert_search_eq(
929            &db,
930            SearchMode::FullText,
931            FilterMode::Global,
932            "r/ls / !ie",
933            0,
934        )
935        .await
936        .unwrap();
937        assert_search_eq(
938            &db,
939            SearchMode::FullText,
940            FilterMode::Global,
941            "meow r/ls/",
942            0,
943        )
944        .await
945        .unwrap();
946        assert_search_eq(&db, SearchMode::FullText, FilterMode::Global, "r//hom/", 1)
947            .await
948            .unwrap();
949        assert_search_eq(
950            &db,
951            SearchMode::FullText,
952            FilterMode::Global,
953            "r//home//",
954            1,
955        )
956        .await
957        .unwrap();
958        assert_search_eq(
959            &db,
960            SearchMode::FullText,
961            FilterMode::Global,
962            "r//home///",
963            0,
964        )
965        .await
966        .unwrap();
967        assert_search_eq(&db, SearchMode::FullText, FilterMode::Global, "/home.*e", 0)
968            .await
969            .unwrap();
970        assert_search_eq(
971            &db,
972            SearchMode::FullText,
973            FilterMode::Global,
974            "r/home.*e",
975            1,
976        )
977        .await
978        .unwrap();
979    }
980
981    #[tokio::test(flavor = "multi_thread")]
982    async fn test_search_fuzzy() {
983        let mut db = Sqlite::new("sqlite::memory:", test_local_timeout())
984            .await
985            .unwrap();
986        new_history_item(&mut db, "ls /home/ellie").await.unwrap();
987        new_history_item(&mut db, "ls /home/frank").await.unwrap();
988        new_history_item(&mut db, "cd /home/Ellie").await.unwrap();
989        new_history_item(&mut db, "/home/ellie/.bin/rustup")
990            .await
991            .unwrap();
992
993        assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "ls /", 3)
994            .await
995            .unwrap();
996        assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "ls/", 2)
997            .await
998            .unwrap();
999        assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "l/h/", 2)
1000            .await
1001            .unwrap();
1002        assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "/h/e", 3)
1003            .await
1004            .unwrap();
1005        assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "/hmoe/", 0)
1006            .await
1007            .unwrap();
1008        assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "ellie/home", 0)
1009            .await
1010            .unwrap();
1011        assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "lsellie", 1)
1012            .await
1013            .unwrap();
1014        assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, " ", 4)
1015            .await
1016            .unwrap();
1017
1018        // single term operators
1019        assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "^ls", 2)
1020            .await
1021            .unwrap();
1022        assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "'ls", 2)
1023            .await
1024            .unwrap();
1025        assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "ellie$", 2)
1026            .await
1027            .unwrap();
1028        assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "!^ls", 2)
1029            .await
1030            .unwrap();
1031        assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "!ellie", 1)
1032            .await
1033            .unwrap();
1034        assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "!ellie$", 2)
1035            .await
1036            .unwrap();
1037
1038        // multiple terms
1039        assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "ls !ellie", 1)
1040            .await
1041            .unwrap();
1042        assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "^ls !e$", 1)
1043            .await
1044            .unwrap();
1045        assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "home !^ls", 2)
1046            .await
1047            .unwrap();
1048        assert_search_eq(
1049            &db,
1050            SearchMode::Fuzzy,
1051            FilterMode::Global,
1052            "'frank | 'rustup",
1053            2,
1054        )
1055        .await
1056        .unwrap();
1057        assert_search_eq(
1058            &db,
1059            SearchMode::Fuzzy,
1060            FilterMode::Global,
1061            "'frank | 'rustup 'ls",
1062            1,
1063        )
1064        .await
1065        .unwrap();
1066
1067        // case matching
1068        assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "Ellie", 1)
1069            .await
1070            .unwrap();
1071
1072        // regex
1073        assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "r/^ls ", 2)
1074            .await
1075            .unwrap();
1076        assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "r/[Ee]llie", 3)
1077            .await
1078            .unwrap();
1079        assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "/h/e r/^ls ", 1)
1080            .await
1081            .unwrap();
1082    }
1083
1084    #[tokio::test(flavor = "multi_thread")]
1085    async fn test_search_reordered_fuzzy() {
1086        let mut db = Sqlite::new("sqlite::memory:", test_local_timeout())
1087            .await
1088            .unwrap();
1089        // test ordering of results: we should choose the first, even though it happened longer ago.
1090
1091        new_history_item(&mut db, "curl").await.unwrap();
1092        new_history_item(&mut db, "corburl").await.unwrap();
1093
1094        // if fuzzy reordering is on, it should come back in a more sensible order
1095        assert_search_commands(
1096            &db,
1097            SearchMode::Fuzzy,
1098            FilterMode::Global,
1099            "curl",
1100            vec!["curl", "corburl"],
1101        )
1102        .await;
1103
1104        assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "xxxx", 0)
1105            .await
1106            .unwrap();
1107        assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "", 2)
1108            .await
1109            .unwrap();
1110    }
1111
1112    #[tokio::test(flavor = "multi_thread")]
1113    async fn test_search_bench_dupes() {
1114        let context = Context {
1115            hostname: "test:host".to_string(),
1116            session: "beepboopiamasession".to_string(),
1117            cwd: "/home/ellie".to_string(),
1118            host_id: "test-host".to_string(),
1119            git_root: None,
1120        };
1121
1122        let mut db = Sqlite::new("sqlite::memory:", test_local_timeout())
1123            .await
1124            .unwrap();
1125        for _i in 1..10000 {
1126            new_history_item(&mut db, "i am a duplicated command")
1127                .await
1128                .unwrap();
1129        }
1130        let start = Instant::now();
1131        let _results = db
1132            .search(
1133                SearchMode::Fuzzy,
1134                FilterMode::Global,
1135                &context,
1136                "",
1137                OptFilters {
1138                    ..Default::default()
1139                },
1140            )
1141            .await
1142            .unwrap();
1143        let duration = start.elapsed();
1144
1145        assert!(duration < Duration::from_secs(15));
1146    }
1147}