1use std::collections::HashMap;
2use std::fmt::Debug;
3use std::ops::Range;
4
5use async_trait::async_trait;
6use atuin_common::record::{EncryptedData, HostId, Record, RecordIdx, RecordStatus};
7use atuin_common::utils::crypto_random_string;
8use atuin_server_database::models::{History, NewHistory, NewSession, NewUser, Session, User};
9use atuin_server_database::{Database, DbError, DbResult};
10use futures_util::TryStreamExt;
11use metrics::counter;
12use serde::{Deserialize, Serialize};
13use sqlx::postgres::PgPoolOptions;
14use sqlx::Row;
15
16use time::{OffsetDateTime, PrimitiveDateTime, UtcOffset};
17use tracing::{instrument, trace};
18use uuid::Uuid;
19use wrappers::{DbHistory, DbRecord, DbSession, DbUser};
20
21mod wrappers;
22
/// Oldest PostgreSQL major version accepted at startup; `Postgres::new`
/// refuses to run against anything older.
const MIN_PG_VERSION: u32 = 14;
24
25#[derive(Clone)]
26pub struct Postgres {
27 pool: sqlx::Pool<sqlx::postgres::Postgres>,
28}
29
30#[derive(Clone, Deserialize, Serialize)]
31pub struct PostgresSettings {
32 pub db_uri: String,
33}
34
35impl Debug for PostgresSettings {
37 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
38 let redacted_uri = url::Url::parse(&self.db_uri)
39 .map(|mut url| {
40 let _ = url.set_password(Some("****"));
41 url.to_string()
42 })
43 .unwrap_or(self.db_uri.clone());
44 f.debug_struct("PostgresSettings")
45 .field("db_uri", &redacted_uri)
46 .finish()
47 }
48}
49
50fn fix_error(error: sqlx::Error) -> DbError {
51 match error {
52 sqlx::Error::RowNotFound => DbError::NotFound,
53 error => DbError::Other(error.into()),
54 }
55}
56
57#[async_trait]
58impl Database for Postgres {
59 type Settings = PostgresSettings;
60 async fn new(settings: &PostgresSettings) -> DbResult<Self> {
61 let pool = PgPoolOptions::new()
62 .max_connections(100)
63 .connect(settings.db_uri.as_str())
64 .await
65 .map_err(fix_error)?;
66
67 let pg_major_version: u32 = pool
70 .acquire()
71 .await
72 .map_err(fix_error)?
73 .server_version_num()
74 .ok_or(DbError::Other(eyre::Report::msg(
75 "could not get PostgreSQL version",
76 )))?
77 / 10000;
78
79 if pg_major_version < MIN_PG_VERSION {
80 return Err(DbError::Other(eyre::Report::msg(format!(
81 "unsupported PostgreSQL version {}, minimum required is {}",
82 pg_major_version, MIN_PG_VERSION
83 ))));
84 }
85
86 sqlx::migrate!("./migrations")
87 .run(&pool)
88 .await
89 .map_err(|error| DbError::Other(error.into()))?;
90
91 Ok(Self { pool })
92 }
93
94 #[instrument(skip_all)]
95 async fn get_session(&self, token: &str) -> DbResult<Session> {
96 sqlx::query_as("select id, user_id, token from sessions where token = $1")
97 .bind(token)
98 .fetch_one(&self.pool)
99 .await
100 .map_err(fix_error)
101 .map(|DbSession(session)| session)
102 }
103
104 #[instrument(skip_all)]
105 async fn get_user(&self, username: &str) -> DbResult<User> {
106 sqlx::query_as(
107 "select id, username, email, password, verified_at from users where username = $1",
108 )
109 .bind(username)
110 .fetch_one(&self.pool)
111 .await
112 .map_err(fix_error)
113 .map(|DbUser(user)| user)
114 }
115
116 #[instrument(skip_all)]
117 async fn user_verified(&self, id: i64) -> DbResult<bool> {
118 let res: (bool,) =
119 sqlx::query_as("select verified_at is not null from users where id = $1")
120 .bind(id)
121 .fetch_one(&self.pool)
122 .await
123 .map_err(fix_error)?;
124
125 Ok(res.0)
126 }
127
128 #[instrument(skip_all)]
129 async fn verify_user(&self, id: i64) -> DbResult<()> {
130 sqlx::query(
131 "update users set verified_at = (current_timestamp at time zone 'utc') where id=$1",
132 )
133 .bind(id)
134 .execute(&self.pool)
135 .await
136 .map_err(fix_error)?;
137
138 Ok(())
139 }
140
141 #[instrument(skip_all)]
146 async fn user_verification_token(&self, id: i64) -> DbResult<String> {
147 const TOKEN_VALID_MINUTES: i64 = 15;
148
149 let token: Option<(String, sqlx::types::time::OffsetDateTime)> = sqlx::query_as(
151 "select token, valid_until from user_verification_token where user_id = $1",
152 )
153 .bind(id)
154 .fetch_optional(&self.pool)
155 .await
156 .map_err(fix_error)?;
157
158 let token = if let Some((token, valid_until)) = token {
159 trace!("Token for user {id} valid until {valid_until}");
160
161 if valid_until > time::OffsetDateTime::now_utc() {
163 token
164 } else {
165 let token = crypto_random_string::<24>();
167
168 sqlx::query("update user_verification_token set token = $2, valid_until = $3 where user_id=$1")
169 .bind(id)
170 .bind(&token)
171 .bind(time::OffsetDateTime::now_utc() + time::Duration::minutes(TOKEN_VALID_MINUTES))
172 .execute(&self.pool)
173 .await
174 .map_err(fix_error)?;
175
176 token
177 }
178 } else {
179 let token = crypto_random_string::<24>();
181
182 sqlx::query("insert into user_verification_token (user_id, token, valid_until) values ($1, $2, $3)")
183 .bind(id)
184 .bind(&token)
185 .bind(time::OffsetDateTime::now_utc() + time::Duration::minutes(TOKEN_VALID_MINUTES))
186 .execute(&self.pool)
187 .await
188 .map_err(fix_error)?;
189
190 token
191 };
192
193 Ok(token)
194 }
195
196 #[instrument(skip_all)]
197 async fn get_session_user(&self, token: &str) -> DbResult<User> {
198 sqlx::query_as(
199 "select users.id, users.username, users.email, users.password, users.verified_at from users
200 inner join sessions
201 on users.id = sessions.user_id
202 and sessions.token = $1",
203 )
204 .bind(token)
205 .fetch_one(&self.pool)
206 .await
207 .map_err(fix_error)
208 .map(|DbUser(user)| user)
209 }
210
211 #[instrument(skip_all)]
212 async fn count_history(&self, user: &User) -> DbResult<i64> {
213 let res: (i64,) = sqlx::query_as(
218 "select count(1) from history
219 where user_id = $1",
220 )
221 .bind(user.id)
222 .fetch_one(&self.pool)
223 .await
224 .map_err(fix_error)?;
225
226 Ok(res.0)
227 }
228
229 #[instrument(skip_all)]
230 async fn total_history(&self) -> DbResult<i64> {
231 let res: (i64,) = sqlx::query_as("select sum(total) from total_history_count_user")
236 .fetch_optional(&self.pool)
237 .await
238 .map_err(fix_error)?
239 .unwrap_or((0,));
240
241 Ok(res.0)
242 }
243
244 #[instrument(skip_all)]
245 async fn count_history_cached(&self, user: &User) -> DbResult<i64> {
246 let res: (i32,) = sqlx::query_as(
247 "select total from total_history_count_user
248 where user_id = $1",
249 )
250 .bind(user.id)
251 .fetch_one(&self.pool)
252 .await
253 .map_err(fix_error)?;
254
255 Ok(res.0 as i64)
256 }
257
258 async fn delete_store(&self, user: &User) -> DbResult<()> {
259 sqlx::query(
260 "delete from store
261 where user_id = $1",
262 )
263 .bind(user.id)
264 .execute(&self.pool)
265 .await
266 .map_err(fix_error)?;
267
268 Ok(())
269 }
270
271 async fn delete_history(&self, user: &User, id: String) -> DbResult<()> {
272 sqlx::query(
273 "update history
274 set deleted_at = $3
275 where user_id = $1
276 and client_id = $2
277 and deleted_at is null", )
279 .bind(user.id)
280 .bind(id)
281 .bind(OffsetDateTime::now_utc())
282 .fetch_all(&self.pool)
283 .await
284 .map_err(fix_error)?;
285
286 Ok(())
287 }
288
289 #[instrument(skip_all)]
290 async fn deleted_history(&self, user: &User) -> DbResult<Vec<String>> {
291 let res = sqlx::query(
296 "select client_id from history
297 where user_id = $1
298 and deleted_at is not null",
299 )
300 .bind(user.id)
301 .fetch_all(&self.pool)
302 .await
303 .map_err(fix_error)?;
304
305 let res = res
306 .iter()
307 .map(|row| row.get::<String, _>("client_id"))
308 .collect();
309
310 Ok(res)
311 }
312
313 #[instrument(skip_all)]
314 async fn count_history_range(
315 &self,
316 user: &User,
317 range: Range<OffsetDateTime>,
318 ) -> DbResult<i64> {
319 let res: (i64,) = sqlx::query_as(
320 "select count(1) from history
321 where user_id = $1
322 and timestamp >= $2::date
323 and timestamp < $3::date",
324 )
325 .bind(user.id)
326 .bind(into_utc(range.start))
327 .bind(into_utc(range.end))
328 .fetch_one(&self.pool)
329 .await
330 .map_err(fix_error)?;
331
332 Ok(res.0)
333 }
334
335 #[instrument(skip_all)]
336 async fn list_history(
337 &self,
338 user: &User,
339 created_after: OffsetDateTime,
340 since: OffsetDateTime,
341 host: &str,
342 page_size: i64,
343 ) -> DbResult<Vec<History>> {
344 let res = sqlx::query_as(
345 "select id, client_id, user_id, hostname, timestamp, data, created_at from history
346 where user_id = $1
347 and hostname != $2
348 and created_at >= $3
349 and timestamp >= $4
350 order by timestamp asc
351 limit $5",
352 )
353 .bind(user.id)
354 .bind(host)
355 .bind(into_utc(created_after))
356 .bind(into_utc(since))
357 .bind(page_size)
358 .fetch(&self.pool)
359 .map_ok(|DbHistory(h)| h)
360 .try_collect()
361 .await
362 .map_err(fix_error)?;
363
364 Ok(res)
365 }
366
367 #[instrument(skip_all)]
368 async fn add_history(&self, history: &[NewHistory]) -> DbResult<()> {
369 let mut tx = self.pool.begin().await.map_err(fix_error)?;
370
371 for i in history {
372 let client_id: &str = &i.client_id;
373 let hostname: &str = &i.hostname;
374 let data: &str = &i.data;
375
376 sqlx::query(
377 "insert into history
378 (client_id, user_id, hostname, timestamp, data)
379 values ($1, $2, $3, $4, $5)
380 on conflict do nothing
381 ",
382 )
383 .bind(client_id)
384 .bind(i.user_id)
385 .bind(hostname)
386 .bind(i.timestamp)
387 .bind(data)
388 .execute(&mut *tx)
389 .await
390 .map_err(fix_error)?;
391 }
392
393 tx.commit().await.map_err(fix_error)?;
394
395 Ok(())
396 }
397
398 #[instrument(skip_all)]
399 async fn delete_user(&self, u: &User) -> DbResult<()> {
400 sqlx::query("delete from sessions where user_id = $1")
401 .bind(u.id)
402 .execute(&self.pool)
403 .await
404 .map_err(fix_error)?;
405
406 sqlx::query("delete from history where user_id = $1")
407 .bind(u.id)
408 .execute(&self.pool)
409 .await
410 .map_err(fix_error)?;
411
412 sqlx::query("delete from store where user_id = $1")
413 .bind(u.id)
414 .execute(&self.pool)
415 .await
416 .map_err(fix_error)?;
417
418 sqlx::query("delete from user_verification_token where user_id = $1")
419 .bind(u.id)
420 .execute(&self.pool)
421 .await
422 .map_err(fix_error)?;
423
424 sqlx::query("delete from total_history_count_user where user_id = $1")
425 .bind(u.id)
426 .execute(&self.pool)
427 .await
428 .map_err(fix_error)?;
429
430 sqlx::query("delete from users where id = $1")
431 .bind(u.id)
432 .execute(&self.pool)
433 .await
434 .map_err(fix_error)?;
435
436 Ok(())
437 }
438
439 #[instrument(skip_all)]
440 async fn update_user_password(&self, user: &User) -> DbResult<()> {
441 sqlx::query(
442 "update users
443 set password = $1
444 where id = $2",
445 )
446 .bind(&user.password)
447 .bind(user.id)
448 .execute(&self.pool)
449 .await
450 .map_err(fix_error)?;
451
452 Ok(())
453 }
454
455 #[instrument(skip_all)]
456 async fn add_user(&self, user: &NewUser) -> DbResult<i64> {
457 let email: &str = &user.email;
458 let username: &str = &user.username;
459 let password: &str = &user.password;
460
461 let res: (i64,) = sqlx::query_as(
462 "insert into users
463 (username, email, password)
464 values($1, $2, $3)
465 returning id",
466 )
467 .bind(username)
468 .bind(email)
469 .bind(password)
470 .fetch_one(&self.pool)
471 .await
472 .map_err(fix_error)?;
473
474 Ok(res.0)
475 }
476
477 #[instrument(skip_all)]
478 async fn add_session(&self, session: &NewSession) -> DbResult<()> {
479 let token: &str = &session.token;
480
481 sqlx::query(
482 "insert into sessions
483 (user_id, token)
484 values($1, $2)",
485 )
486 .bind(session.user_id)
487 .bind(token)
488 .execute(&self.pool)
489 .await
490 .map_err(fix_error)?;
491
492 Ok(())
493 }
494
495 #[instrument(skip_all)]
496 async fn get_user_session(&self, u: &User) -> DbResult<Session> {
497 sqlx::query_as("select id, user_id, token from sessions where user_id = $1")
498 .bind(u.id)
499 .fetch_one(&self.pool)
500 .await
501 .map_err(fix_error)
502 .map(|DbSession(session)| session)
503 }
504
505 #[instrument(skip_all)]
506 async fn oldest_history(&self, user: &User) -> DbResult<History> {
507 sqlx::query_as(
508 "select id, client_id, user_id, hostname, timestamp, data, created_at from history
509 where user_id = $1
510 order by timestamp asc
511 limit 1",
512 )
513 .bind(user.id)
514 .fetch_one(&self.pool)
515 .await
516 .map_err(fix_error)
517 .map(|DbHistory(h)| h)
518 }
519
520 #[instrument(skip_all)]
521 async fn add_records(&self, user: &User, records: &[Record<EncryptedData>]) -> DbResult<()> {
522 let mut tx = self.pool.begin().await.map_err(fix_error)?;
523
524 let mut heads = HashMap::<(HostId, &str), u64>::new();
532
533 for i in records {
534 let id = atuin_common::utils::uuid_v7();
535
536 sqlx::query(
537 "insert into store
538 (id, client_id, host, idx, timestamp, version, tag, data, cek, user_id)
539 values ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
540 on conflict do nothing
541 ",
542 )
543 .bind(id)
544 .bind(i.id)
545 .bind(i.host.id)
546 .bind(i.idx as i64)
547 .bind(i.timestamp as i64) .bind(&i.version)
549 .bind(&i.tag)
550 .bind(&i.data.data)
551 .bind(&i.data.content_encryption_key)
552 .bind(user.id)
553 .execute(&mut *tx)
554 .await
555 .map_err(fix_error)?;
556
557 heads
559 .entry((i.host.id, &i.tag))
560 .and_modify(|e| {
561 if i.idx > *e {
562 *e = i.idx
563 }
564 })
565 .or_insert(i.idx);
566 }
567
568 for ((host, tag), idx) in heads {
570 sqlx::query(
571 "insert into store_idx_cache
572 (user_id, host, tag, idx)
573 values ($1, $2, $3, $4)
574 on conflict(user_id, host, tag) do update set idx = greatest(store_idx_cache.idx, $4)
575 ",
576 )
577 .bind(user.id)
578 .bind(host)
579 .bind(tag)
580 .bind(idx as i64)
581 .execute(&mut *tx)
582 .await
583 .map_err(fix_error)?;
584 }
585
586 tx.commit().await.map_err(fix_error)?;
587
588 Ok(())
589 }
590
591 #[instrument(skip_all)]
592 async fn next_records(
593 &self,
594 user: &User,
595 host: HostId,
596 tag: String,
597 start: Option<RecordIdx>,
598 count: u64,
599 ) -> DbResult<Vec<Record<EncryptedData>>> {
600 tracing::debug!("{:?} - {:?} - {:?}", host, tag, start);
601 let start = start.unwrap_or(0);
602
603 let records: Result<Vec<DbRecord>, DbError> = sqlx::query_as(
604 "select client_id, host, idx, timestamp, version, tag, data, cek from store
605 where user_id = $1
606 and tag = $2
607 and host = $3
608 and idx >= $4
609 order by idx asc
610 limit $5",
611 )
612 .bind(user.id)
613 .bind(tag.clone())
614 .bind(host)
615 .bind(start as i64)
616 .bind(count as i64)
617 .fetch_all(&self.pool)
618 .await
619 .map_err(fix_error);
620
621 let ret = match records {
622 Ok(records) => {
623 let records: Vec<Record<EncryptedData>> = records
624 .into_iter()
625 .map(|f| {
626 let record: Record<EncryptedData> = f.into();
627 record
628 })
629 .collect();
630
631 records
632 }
633 Err(DbError::NotFound) => {
634 tracing::debug!("no records found in store: {:?}/{}", host, tag);
635 return Ok(vec![]);
636 }
637 Err(e) => return Err(e),
638 };
639
640 Ok(ret)
641 }
642
643 async fn status(&self, user: &User) -> DbResult<RecordStatus> {
644 const STATUS_SQL: &str =
645 "select host, tag, max(idx) from store where user_id = $1 group by host, tag";
646
647 let mut res: Vec<(Uuid, String, i64)> = sqlx::query_as(STATUS_SQL)
648 .bind(user.id)
649 .fetch_all(&self.pool)
650 .await
651 .map_err(fix_error)?;
652 res.sort();
653
654 let mut cached_res: Vec<(Uuid, String, i64)> =
661 sqlx::query_as("select host, tag, idx from store_idx_cache where user_id = $1")
662 .bind(user.id)
663 .fetch_all(&self.pool)
664 .await
665 .map_err(fix_error)?;
666 cached_res.sort();
667
668 let mut status = RecordStatus::new();
669
670 let equal = res == cached_res;
671
672 if equal {
673 counter!("atuin_store_idx_cache_consistent", 1);
674 } else {
675 tracing::debug!(user = user.username, cache_match = equal, res = ?res, cached = ?cached_res, "record store index request");
677 counter!("atuin_store_idx_cache_inconsistent", 1);
678 };
679
680 for i in res.iter() {
681 status.set_raw(HostId(i.0), i.1.clone(), i.2 as u64);
682 }
683
684 Ok(status)
685 }
686}
687
688fn into_utc(x: OffsetDateTime) -> PrimitiveDateTime {
689 let x = x.to_offset(UtcOffset::UTC);
690 PrimitiveDateTime::new(x.date(), x.time())
691}
692
#[cfg(test)]
mod tests {
    use time::macros::datetime;

    use crate::into_utc;

    /// `into_utc` converts to UTC wall-clock time and round-trips via `assume_utc`.
    #[test]
    fn utc() {
        let dt = datetime!(2023-09-26 15:11:02 +05:30);
        assert_eq!(into_utc(dt), datetime!(2023-09-26 09:41:02));
        assert_eq!(into_utc(dt).assume_utc(), dt);

        let dt = datetime!(2023-09-26 15:11:02 -07:00);
        assert_eq!(into_utc(dt), datetime!(2023-09-26 22:11:02));
        assert_eq!(into_utc(dt).assume_utc(), dt);

        let dt = datetime!(2023-09-26 15:11:02 +00:00);
        assert_eq!(into_utc(dt), datetime!(2023-09-26 15:11:02));
        assert_eq!(into_utc(dt).assume_utc(), dt);
    }

    /// Converting to UTC can move the calendar date (and year) in either direction.
    #[test]
    fn utc_crosses_date_boundaries() {
        // Positive offset shortly after local midnight lands on the previous UTC day.
        let dt = datetime!(2023-09-26 01:00:00 +05:30);
        assert_eq!(into_utc(dt), datetime!(2023-09-25 19:30:00));
        assert_eq!(into_utc(dt).assume_utc(), dt);

        // Negative offset late in the evening lands on the next UTC day and year.
        let dt = datetime!(2023-12-31 20:00:00 -07:00);
        assert_eq!(into_utc(dt), datetime!(2024-01-01 03:00:00));
        assert_eq!(into_utc(dt).assume_utc(), dt);
    }
}
713}