atuin_server_postgres/lib.rs

use std::collections::HashMap;
use std::fmt::Debug;
use std::ops::Range;

use async_trait::async_trait;
use atuin_common::record::{EncryptedData, HostId, Record, RecordIdx, RecordStatus};
use atuin_common::utils::crypto_random_string;
use atuin_server_database::models::{History, NewHistory, NewSession, NewUser, Session, User};
use atuin_server_database::{Database, DbError, DbResult};
use futures_util::TryStreamExt;
use metrics::counter;
use serde::{Deserialize, Serialize};
use sqlx::postgres::PgPoolOptions;
use sqlx::Row;

use time::{OffsetDateTime, PrimitiveDateTime, UtcOffset};
use tracing::{instrument, trace};
use uuid::Uuid;
use wrappers::{DbHistory, DbRecord, DbSession, DbUser};

mod wrappers;

const MIN_PG_VERSION: u32 = 14;

#[derive(Clone)]
pub struct Postgres {
    pool: sqlx::Pool<sqlx::postgres::Postgres>,
}

#[derive(Clone, Deserialize, Serialize)]
pub struct PostgresSettings {
    pub db_uri: String,
}
// Do our best to redact passwords so they're not logged in the event of an error.
impl Debug for PostgresSettings {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let redacted_uri = url::Url::parse(&self.db_uri)
            .map(|mut url| {
                let _ = url.set_password(Some("****"));
                url.to_string()
            })
            .unwrap_or_else(|_| self.db_uri.clone());
        f.debug_struct("PostgresSettings")
            .field("db_uri", &redacted_uri)
            .finish()
    }
}
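
// A minimal sketch of the redaction above, using hypothetical credentials;
// it asserts only the property we care about: the password never shows up
// in Debug output.
#[cfg(test)]
mod redaction_tests {
    use super::PostgresSettings;

    #[test]
    fn debug_redacts_password() {
        let settings = PostgresSettings {
            db_uri: "postgres://atuin:hunter2@localhost:5432/atuin".to_string(),
        };
        let debug = format!("{settings:?}");
        assert!(!debug.contains("hunter2"));
        assert!(debug.contains("****"));
    }
}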

fn fix_error(error: sqlx::Error) -> DbError {
    match error {
        sqlx::Error::RowNotFound => DbError::NotFound,
        error => DbError::Other(error.into()),
    }
}

#[async_trait]
impl Database for Postgres {
    type Settings = PostgresSettings;
    async fn new(settings: &PostgresSettings) -> DbResult<Self> {
        let pool = PgPoolOptions::new()
            .max_connections(100)
            .connect(settings.db_uri.as_str())
            .await
            .map_err(fix_error)?;

        // server_version_num reports the server's full version as an integer
        // (e.g. 140005 for 14.5); dividing by 10000 leaves the major version.
        // The call returns None for servers older than 8.x.
        let pg_major_version: u32 = pool
            .acquire()
            .await
            .map_err(fix_error)?
            .server_version_num()
            .ok_or(DbError::Other(eyre::Report::msg(
                "could not get PostgreSQL version",
            )))?
            / 10000;

        if pg_major_version < MIN_PG_VERSION {
            return Err(DbError::Other(eyre::Report::msg(format!(
                "unsupported PostgreSQL version {}, minimum required is {}",
                pg_major_version, MIN_PG_VERSION
            ))));
        }

        sqlx::migrate!("./migrations")
            .run(&pool)
            .await
            .map_err(|error| DbError::Other(error.into()))?;

        Ok(Self { pool })
    }

    #[instrument(skip_all)]
    async fn get_session(&self, token: &str) -> DbResult<Session> {
        sqlx::query_as("select id, user_id, token from sessions where token = $1")
            .bind(token)
            .fetch_one(&self.pool)
            .await
            .map_err(fix_error)
            .map(|DbSession(session)| session)
    }

    #[instrument(skip_all)]
    async fn get_user(&self, username: &str) -> DbResult<User> {
        sqlx::query_as(
            "select id, username, email, password, verified_at from users where username = $1",
        )
        .bind(username)
        .fetch_one(&self.pool)
        .await
        .map_err(fix_error)
        .map(|DbUser(user)| user)
    }

    #[instrument(skip_all)]
    async fn user_verified(&self, id: i64) -> DbResult<bool> {
        let res: (bool,) =
            sqlx::query_as("select verified_at is not null from users where id = $1")
                .bind(id)
                .fetch_one(&self.pool)
                .await
                .map_err(fix_error)?;

        Ok(res.0)
    }

    #[instrument(skip_all)]
    async fn verify_user(&self, id: i64) -> DbResult<()> {
        sqlx::query(
            "update users set verified_at = (current_timestamp at time zone 'utc') where id=$1",
        )
        .bind(id)
        .execute(&self.pool)
        .await
        .map_err(fix_error)?;

        Ok(())
    }

    /// Return a valid verification token for the user.
    /// If the user has no token, create one, insert it, and return it.
    /// If the user's token has expired, replace it with a fresh one and return that.
    /// If the user already has a valid token, return it as-is.
    #[instrument(skip_all)]
    async fn user_verification_token(&self, id: i64) -> DbResult<String> {
        const TOKEN_VALID_MINUTES: i64 = 15;

        // First we check if there is a verification token
        let token: Option<(String, sqlx::types::time::OffsetDateTime)> = sqlx::query_as(
            "select token, valid_until from user_verification_token where user_id = $1",
        )
        .bind(id)
        .fetch_optional(&self.pool)
        .await
        .map_err(fix_error)?;

        let token = if let Some((token, valid_until)) = token {
            trace!("Token for user {id} valid until {valid_until}");

            // We have a token, AND it's still valid
            if valid_until > time::OffsetDateTime::now_utc() {
                token
            } else {
                // token has expired. generate a new one, return it
                let token = crypto_random_string::<24>();

                sqlx::query("update user_verification_token set token = $2, valid_until = $3 where user_id=$1")
                    .bind(id)
                    .bind(&token)
                    .bind(time::OffsetDateTime::now_utc() + time::Duration::minutes(TOKEN_VALID_MINUTES))
                    .execute(&self.pool)
                    .await
                    .map_err(fix_error)?;

                token
            }
        } else {
            // No token in the database! Generate one, insert it
            let token = crypto_random_string::<24>();

            sqlx::query("insert into user_verification_token (user_id, token, valid_until) values ($1, $2, $3)")
                .bind(id)
                .bind(&token)
                .bind(time::OffsetDateTime::now_utc() + time::Duration::minutes(TOKEN_VALID_MINUTES))
                .execute(&self.pool)
                .await
                .map_err(fix_error)?;

            token
        };

        Ok(token)
    }

    #[instrument(skip_all)]
    async fn get_session_user(&self, token: &str) -> DbResult<User> {
        sqlx::query_as(
            "select users.id, users.username, users.email, users.password, users.verified_at from users
            inner join sessions
            on users.id = sessions.user_id
            and sessions.token = $1",
        )
        .bind(token)
        .fetch_one(&self.pool)
        .await
        .map_err(fix_error)
        .map(|DbUser(user)| user)
    }

    #[instrument(skip_all)]
    async fn count_history(&self, user: &User) -> DbResult<i64> {
        let res: (i64,) = sqlx::query_as(
            "select count(1) from history
            where user_id = $1",
        )
        .bind(user.id)
        .fetch_one(&self.pool)
        .await
        .map_err(fix_error)?;

        Ok(res.0)
    }

    #[instrument(skip_all)]
    async fn total_history(&self) -> DbResult<i64> {
        // The count cache is new, so it may not have rows for every user yet.
        // Rows appear as soon as users post some new history, but handle the
        // empty-table edge case here.

        let res: (i64,) = sqlx::query_as("select sum(total) from total_history_count_user")
            .fetch_optional(&self.pool)
            .await
            .map_err(fix_error)?
            .unwrap_or((0,));

        Ok(res.0)
    }

    #[instrument(skip_all)]
    async fn count_history_cached(&self, user: &User) -> DbResult<i64> {
        let res: (i32,) = sqlx::query_as(
            "select total from total_history_count_user
            where user_id = $1",
        )
        .bind(user.id)
        .fetch_one(&self.pool)
        .await
        .map_err(fix_error)?;

        Ok(res.0 as i64)
    }

    async fn delete_store(&self, user: &User) -> DbResult<()> {
        sqlx::query(
            "delete from store
            where user_id = $1",
        )
        .bind(user.id)
        .execute(&self.pool)
        .await
        .map_err(fix_error)?;

        Ok(())
    }

    async fn delete_history(&self, user: &User, id: String) -> DbResult<()> {
        sqlx::query(
            "update history
            set deleted_at = $3
            where user_id = $1
            and client_id = $2
            and deleted_at is null", // don't just keep setting it
        )
        .bind(user.id)
        .bind(id)
        .bind(OffsetDateTime::now_utc())
        .execute(&self.pool)
        .await
        .map_err(fix_error)?;

        Ok(())
    }

    #[instrument(skip_all)]
    async fn deleted_history(&self, user: &User) -> DbResult<Vec<String>> {
        let res = sqlx::query(
            "select client_id from history
            where user_id = $1
            and deleted_at is not null",
        )
        .bind(user.id)
        .fetch_all(&self.pool)
        .await
        .map_err(fix_error)?;

        let res = res
            .iter()
            .map(|row| row.get::<String, _>("client_id"))
            .collect();

        Ok(res)
    }

    #[instrument(skip_all)]
    async fn count_history_range(
        &self,
        user: &User,
        range: Range<OffsetDateTime>,
    ) -> DbResult<i64> {
        let res: (i64,) = sqlx::query_as(
            "select count(1) from history
            where user_id = $1
            and timestamp >= $2::date
            and timestamp < $3::date",
        )
        .bind(user.id)
        .bind(into_utc(range.start))
        .bind(into_utc(range.end))
        .fetch_one(&self.pool)
        .await
        .map_err(fix_error)?;

        Ok(res.0)
    }

    #[instrument(skip_all)]
    async fn list_history(
        &self,
        user: &User,
        created_after: OffsetDateTime,
        since: OffsetDateTime,
        host: &str,
        page_size: i64,
    ) -> DbResult<Vec<History>> {
        let res = sqlx::query_as(
            "select id, client_id, user_id, hostname, timestamp, data, created_at from history
            where user_id = $1
            and hostname != $2
            and created_at >= $3
            and timestamp >= $4
            order by timestamp asc
            limit $5",
        )
        .bind(user.id)
        .bind(host)
        .bind(into_utc(created_after))
        .bind(into_utc(since))
        .bind(page_size)
        .fetch(&self.pool)
        .map_ok(|DbHistory(h)| h)
        .try_collect()
        .await
        .map_err(fix_error)?;

        Ok(res)
    }

    #[instrument(skip_all)]
    async fn add_history(&self, history: &[NewHistory]) -> DbResult<()> {
        let mut tx = self.pool.begin().await.map_err(fix_error)?;

        for i in history {
            let client_id: &str = &i.client_id;
            let hostname: &str = &i.hostname;
            let data: &str = &i.data;

            sqlx::query(
                "insert into history
                    (client_id, user_id, hostname, timestamp, data)
                values ($1, $2, $3, $4, $5)
                on conflict do nothing
                ",
            )
            .bind(client_id)
            .bind(i.user_id)
            .bind(hostname)
            .bind(i.timestamp)
            .bind(data)
            .execute(&mut *tx)
            .await
            .map_err(fix_error)?;
        }

        tx.commit().await.map_err(fix_error)?;

        Ok(())
    }

    #[instrument(skip_all)]
    async fn delete_user(&self, u: &User) -> DbResult<()> {
        sqlx::query("delete from sessions where user_id = $1")
            .bind(u.id)
            .execute(&self.pool)
            .await
            .map_err(fix_error)?;

        sqlx::query("delete from history where user_id = $1")
            .bind(u.id)
            .execute(&self.pool)
            .await
            .map_err(fix_error)?;

        sqlx::query("delete from store where user_id = $1")
            .bind(u.id)
            .execute(&self.pool)
            .await
            .map_err(fix_error)?;

        sqlx::query("delete from user_verification_token where user_id = $1")
            .bind(u.id)
            .execute(&self.pool)
            .await
            .map_err(fix_error)?;

        sqlx::query("delete from total_history_count_user where user_id = $1")
            .bind(u.id)
            .execute(&self.pool)
            .await
            .map_err(fix_error)?;

        sqlx::query("delete from users where id = $1")
            .bind(u.id)
            .execute(&self.pool)
            .await
            .map_err(fix_error)?;

        Ok(())
    }

    #[instrument(skip_all)]
    async fn update_user_password(&self, user: &User) -> DbResult<()> {
        sqlx::query(
            "update users
            set password = $1
            where id = $2",
        )
        .bind(&user.password)
        .bind(user.id)
        .execute(&self.pool)
        .await
        .map_err(fix_error)?;

        Ok(())
    }

    #[instrument(skip_all)]
    async fn add_user(&self, user: &NewUser) -> DbResult<i64> {
        let email: &str = &user.email;
        let username: &str = &user.username;
        let password: &str = &user.password;

        let res: (i64,) = sqlx::query_as(
            "insert into users
                (username, email, password)
            values($1, $2, $3)
            returning id",
        )
        .bind(username)
        .bind(email)
        .bind(password)
        .fetch_one(&self.pool)
        .await
        .map_err(fix_error)?;

        Ok(res.0)
    }

    #[instrument(skip_all)]
    async fn add_session(&self, session: &NewSession) -> DbResult<()> {
        let token: &str = &session.token;

        sqlx::query(
            "insert into sessions
                (user_id, token)
            values($1, $2)",
        )
        .bind(session.user_id)
        .bind(token)
        .execute(&self.pool)
        .await
        .map_err(fix_error)?;

        Ok(())
    }

    #[instrument(skip_all)]
    async fn get_user_session(&self, u: &User) -> DbResult<Session> {
        sqlx::query_as("select id, user_id, token from sessions where user_id = $1")
            .bind(u.id)
            .fetch_one(&self.pool)
            .await
            .map_err(fix_error)
            .map(|DbSession(session)| session)
    }

    #[instrument(skip_all)]
    async fn oldest_history(&self, user: &User) -> DbResult<History> {
        sqlx::query_as(
            "select id, client_id, user_id, hostname, timestamp, data, created_at from history
            where user_id = $1
            order by timestamp asc
            limit 1",
        )
        .bind(user.id)
        .fetch_one(&self.pool)
        .await
        .map_err(fix_error)
        .map(|DbHistory(h)| h)
    }

    #[instrument(skip_all)]
    async fn add_records(&self, user: &User, records: &[Record<EncryptedData>]) -> DbResult<()> {
        let mut tx = self.pool.begin().await.map_err(fix_error)?;

        // Clients only push records past the current head, so the largest idx
        // in this batch is the new head for its (host, tag): we can deduce the
        // max without making further database queries. Doing the query on this
        // small amount of data should be much, much faster.
        //
        // Worst case, say we get this wrong. We end up caching data that isn't actually the max
        // idx, so clients upload again. The cache logic can be verified with a sql query anyway :)
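        //
        // A sketch of such a verification query (assuming only the store and
        // store_idx_cache columns used in this file); any row returned is a
        // cache entry that disagrees with the real head in the store:
        //
        //   select s.user_id, s.host, s.tag, max(s.idx) as head, c.idx as cached
        //   from store s
        //   join store_idx_cache c using (user_id, host, tag)
        //   group by s.user_id, s.host, s.tag, c.idx
        //   having max(s.idx) <> c.idx;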

        let mut heads = HashMap::<(HostId, &str), u64>::new();

        for i in records {
            let id = atuin_common::utils::uuid_v7();

            sqlx::query(
                "insert into store
                    (id, client_id, host, idx, timestamp, version, tag, data, cek, user_id)
                values ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
                on conflict do nothing
                ",
            )
            .bind(id)
            .bind(i.id)
            .bind(i.host.id)
            .bind(i.idx as i64)
            .bind(i.timestamp as i64) // the u64 -> i64 cast can lose the top bit, but i64 timestamps still reach far into the future
            .bind(&i.version)
            .bind(&i.tag)
            .bind(&i.data.data)
            .bind(&i.data.content_encryption_key)
            .bind(user.id)
            .execute(&mut *tx)
            .await
            .map_err(fix_error)?;

            // we're already iterating, so track the head for each (host, tag) as we go
            heads
                .entry((i.host.id, &i.tag))
                .and_modify(|e| {
                    if i.idx > *e {
                        *e = i.idx
                    }
                })
                .or_insert(i.idx);
        }

        // we've built the map of heads for this push, so commit it to the database
        for ((host, tag), idx) in heads {
            sqlx::query(
                "insert into store_idx_cache
                    (user_id, host, tag, idx)
                values ($1, $2, $3, $4)
                on conflict(user_id, host, tag) do update set idx = greatest(store_idx_cache.idx, $4)
                ",
            )
            .bind(user.id)
            .bind(host)
            .bind(tag)
            .bind(idx as i64)
            .execute(&mut *tx)
            .await
            .map_err(fix_error)?;
        }

        tx.commit().await.map_err(fix_error)?;

        Ok(())
    }

    #[instrument(skip_all)]
    async fn next_records(
        &self,
        user: &User,
        host: HostId,
        tag: String,
        start: Option<RecordIdx>,
        count: u64,
    ) -> DbResult<Vec<Record<EncryptedData>>> {
        tracing::debug!("{:?} - {:?} - {:?}", host, tag, start);
        let start = start.unwrap_or(0);

        let records: Result<Vec<DbRecord>, DbError> = sqlx::query_as(
            "select client_id, host, idx, timestamp, version, tag, data, cek from store
                    where user_id = $1
                    and tag = $2
                    and host = $3
                    and idx >= $4
                    order by idx asc
                    limit $5",
        )
        .bind(user.id)
        .bind(tag.clone())
        .bind(host)
        .bind(start as i64)
        .bind(count as i64)
        .fetch_all(&self.pool)
        .await
        .map_err(fix_error);

        let ret = match records {
            Ok(records) => {
                let records: Vec<Record<EncryptedData>> = records
                    .into_iter()
                    .map(|f| {
                        let record: Record<EncryptedData> = f.into();
                        record
                    })
                    .collect();

                records
            }
            Err(DbError::NotFound) => {
                tracing::debug!("no records found in store: {:?}/{}", host, tag);
                return Ok(vec![]);
            }
            Err(e) => return Err(e),
        };

        Ok(ret)
    }

    async fn status(&self, user: &User) -> DbResult<RecordStatus> {
        const STATUS_SQL: &str =
            "select host, tag, max(idx) from store where user_id = $1 group by host, tag";

        let mut res: Vec<(Uuid, String, i64)> = sqlx::query_as(STATUS_SQL)
            .bind(user.id)
            .fetch_all(&self.pool)
            .await
            .map_err(fix_error)?;
        res.sort();

        // We're temporarily increasing latency in order to improve confidence in the cache.
        // If it runs for a few days, and we confirm that cached values are equal to realtime, we
        // can replace realtime with cached.
        //
        // But let's check for now, so sync doesn't do Weird Things.

        let mut cached_res: Vec<(Uuid, String, i64)> =
            sqlx::query_as("select host, tag, idx from store_idx_cache where user_id = $1")
                .bind(user.id)
                .fetch_all(&self.pool)
                .await
                .map_err(fix_error)?;
        cached_res.sort();

        let mut status = RecordStatus::new();

        let equal = res == cached_res;

        if equal {
            counter!("atuin_store_idx_cache_consistent", 1);
        } else {
            // log the values if we have an inconsistent cache
            tracing::debug!(user = user.username, cache_match = equal, res = ?res, cached = ?cached_res, "record store index request");
            counter!("atuin_store_idx_cache_inconsistent", 1);
        };

        for i in res.iter() {
            status.set_raw(HostId(i.0), i.1.clone(), i.2 as u64);
        }

        Ok(status)
    }
}

fn into_utc(x: OffsetDateTime) -> PrimitiveDateTime {
    let x = x.to_offset(UtcOffset::UTC);
    PrimitiveDateTime::new(x.date(), x.time())
}

#[cfg(test)]
mod tests {
    use time::macros::datetime;

    use crate::into_utc;

    #[test]
    fn utc() {
        let dt = datetime!(2023-09-26 15:11:02 +05:30);
        assert_eq!(into_utc(dt), datetime!(2023-09-26 09:41:02));
        assert_eq!(into_utc(dt).assume_utc(), dt);

        let dt = datetime!(2023-09-26 15:11:02 -07:00);
        assert_eq!(into_utc(dt), datetime!(2023-09-26 22:11:02));
        assert_eq!(into_utc(dt).assume_utc(), dt);

        let dt = datetime!(2023-09-26 15:11:02 +00:00);
        assert_eq!(into_utc(dt), datetime!(2023-09-26 15:11:02));
        assert_eq!(into_utc(dt).assume_utc(), dt);
    }
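
    // A minimal sketch checking fix_error's RowNotFound mapping defined near
    // the top of this file; it relies only on the DbError variants already
    // used there.
    #[test]
    fn row_not_found_maps_to_not_found() {
        use atuin_server_database::DbError;

        use crate::fix_error;

        assert!(matches!(
            fix_error(sqlx::Error::RowNotFound),
            DbError::NotFound
        ));
    }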
}