atuin_server/
lib.rs

1#![forbid(unsafe_code)]
2
3use std::future::Future;
4use std::net::SocketAddr;
5
6use atuin_server_database::Database;
7use axum::{serve, Router};
8use axum_server::tls_rustls::RustlsConfig;
9use axum_server::Handle;
10use eyre::{eyre, Context, Result};
11
12mod handlers;
13mod metrics;
14mod router;
15mod utils;
16
17pub use settings::example_config;
18pub use settings::Settings;
19
20pub mod settings;
21
22use tokio::net::TcpListener;
23use tokio::signal;
24
25#[cfg(target_family = "unix")]
26async fn shutdown_signal() {
27    let mut term = signal::unix::signal(signal::unix::SignalKind::terminate())
28        .expect("failed to register signal handler");
29    let mut interrupt = signal::unix::signal(signal::unix::SignalKind::interrupt())
30        .expect("failed to register signal handler");
31
32    tokio::select! {
33        _ = term.recv() => {},
34        _ = interrupt.recv() => {},
35    };
36    eprintln!("Shutting down gracefully...");
37}
38
/// Resolves on the first Ctrl-C event, then announces the shutdown on stderr.
///
/// # Panics
///
/// Panics if the Ctrl-C handler cannot be registered with the tokio runtime.
#[cfg(target_family = "windows")]
async fn shutdown_signal() {
    let mut ctrl_c = signal::windows::ctrl_c().expect("failed to register signal handler");
    ctrl_c.recv().await;
    eprintln!("Shutting down gracefully...");
}
47
48pub async fn launch<Db: Database>(
49    settings: Settings<Db::Settings>,
50    addr: SocketAddr,
51) -> Result<()> {
52    if settings.tls.enable {
53        launch_with_tls::<Db>(settings, addr, shutdown_signal()).await
54    } else {
55        launch_with_tcp_listener::<Db>(
56            settings,
57            TcpListener::bind(addr)
58                .await
59                .context("could not connect to socket")?,
60            shutdown_signal(),
61        )
62        .await
63    }
64}
65
66pub async fn launch_with_tcp_listener<Db: Database>(
67    settings: Settings<Db::Settings>,
68    listener: TcpListener,
69    shutdown: impl Future<Output = ()> + Send + 'static,
70) -> Result<()> {
71    let r = make_router::<Db>(settings).await?;
72
73    serve(listener, r.into_make_service())
74        .with_graceful_shutdown(shutdown)
75        .await?;
76
77    Ok(())
78}
79
80async fn launch_with_tls<Db: Database>(
81    settings: Settings<Db::Settings>,
82    addr: SocketAddr,
83    shutdown: impl Future<Output = ()>,
84) -> Result<()> {
85    let crypto_provider = rustls::crypto::ring::default_provider().install_default();
86    if crypto_provider.is_err() {
87        return Err(eyre!("Failed to install default crypto provider"));
88    }
89    let rustls_config = RustlsConfig::from_pem_file(
90        settings.tls.cert_path.clone(),
91        settings.tls.pkey_path.clone(),
92    )
93    .await;
94    if rustls_config.is_err() {
95        return Err(eyre!("Failed to load TLS key and/or certificate"));
96    }
97    let rustls_config = rustls_config.unwrap();
98
99    let r = make_router::<Db>(settings).await?;
100
101    let handle = Handle::new();
102
103    let server = axum_server::bind_rustls(addr, rustls_config)
104        .handle(handle.clone())
105        .serve(r.into_make_service());
106
107    tokio::select! {
108        _ = server => {}
109        _ = shutdown => {
110            handle.graceful_shutdown(None);
111        }
112    }
113
114    Ok(())
115}
116
117// The separate listener means it's much easier to ensure metrics are not accidentally exposed to
118// the public.
119pub async fn launch_metrics_server(host: String, port: u16) -> Result<()> {
120    let listener = TcpListener::bind((host, port))
121        .await
122        .context("failed to bind metrics tcp")?;
123
124    let recorder_handle = metrics::setup_metrics_recorder();
125
126    let router = Router::new().route(
127        "/metrics",
128        axum::routing::get(move || std::future::ready(recorder_handle.render())),
129    );
130
131    serve(listener, router.into_make_service())
132        .with_graceful_shutdown(shutdown_signal())
133        .await?;
134
135    Ok(())
136}
137
138async fn make_router<Db: Database>(
139    settings: Settings<<Db as Database>::Settings>,
140) -> Result<Router, eyre::Error> {
141    let db = Db::new(&settings.db_settings)
142        .await
143        .wrap_err_with(|| format!("failed to connect to db: {:?}", settings.db_settings))?;
144    let r = router::router(db, settings);
145    Ok(r)
146}