1#![forbid(unsafe_code)]
2
3use std::future::Future;
4use std::net::SocketAddr;
5
6use atuin_server_database::Database;
7use axum::{serve, Router};
8use axum_server::tls_rustls::RustlsConfig;
9use axum_server::Handle;
10use eyre::{eyre, Context, Result};
11
12mod handlers;
13mod metrics;
14mod router;
15mod utils;
16
17pub use settings::example_config;
18pub use settings::Settings;
19
20pub mod settings;
21
22use tokio::net::TcpListener;
23use tokio::signal;
24
25#[cfg(target_family = "unix")]
26async fn shutdown_signal() {
27 let mut term = signal::unix::signal(signal::unix::SignalKind::terminate())
28 .expect("failed to register signal handler");
29 let mut interrupt = signal::unix::signal(signal::unix::SignalKind::interrupt())
30 .expect("failed to register signal handler");
31
32 tokio::select! {
33 _ = term.recv() => {},
34 _ = interrupt.recv() => {},
35 };
36 eprintln!("Shutting down gracefully...");
37}
38
/// Waits until the console delivers a Ctrl-C event, then prints a shutdown notice to stderr.
///
/// # Panics
///
/// Panics if the Ctrl-C listener cannot be registered.
#[cfg(target_family = "windows")]
async fn shutdown_signal() {
    let mut ctrl_c_events = signal::windows::ctrl_c().expect("failed to register signal handler");

    // Only the first event matters here.
    ctrl_c_events.recv().await;

    eprintln!("Shutting down gracefully...");
}
47
48pub async fn launch<Db: Database>(
49 settings: Settings<Db::Settings>,
50 addr: SocketAddr,
51) -> Result<()> {
52 if settings.tls.enable {
53 launch_with_tls::<Db>(settings, addr, shutdown_signal()).await
54 } else {
55 launch_with_tcp_listener::<Db>(
56 settings,
57 TcpListener::bind(addr)
58 .await
59 .context("could not connect to socket")?,
60 shutdown_signal(),
61 )
62 .await
63 }
64}
65
66pub async fn launch_with_tcp_listener<Db: Database>(
67 settings: Settings<Db::Settings>,
68 listener: TcpListener,
69 shutdown: impl Future<Output = ()> + Send + 'static,
70) -> Result<()> {
71 let r = make_router::<Db>(settings).await?;
72
73 serve(listener, r.into_make_service())
74 .with_graceful_shutdown(shutdown)
75 .await?;
76
77 Ok(())
78}
79
80async fn launch_with_tls<Db: Database>(
81 settings: Settings<Db::Settings>,
82 addr: SocketAddr,
83 shutdown: impl Future<Output = ()>,
84) -> Result<()> {
85 let crypto_provider = rustls::crypto::ring::default_provider().install_default();
86 if crypto_provider.is_err() {
87 return Err(eyre!("Failed to install default crypto provider"));
88 }
89 let rustls_config = RustlsConfig::from_pem_file(
90 settings.tls.cert_path.clone(),
91 settings.tls.pkey_path.clone(),
92 )
93 .await;
94 if rustls_config.is_err() {
95 return Err(eyre!("Failed to load TLS key and/or certificate"));
96 }
97 let rustls_config = rustls_config.unwrap();
98
99 let r = make_router::<Db>(settings).await?;
100
101 let handle = Handle::new();
102
103 let server = axum_server::bind_rustls(addr, rustls_config)
104 .handle(handle.clone())
105 .serve(r.into_make_service());
106
107 tokio::select! {
108 _ = server => {}
109 _ = shutdown => {
110 handle.graceful_shutdown(None);
111 }
112 }
113
114 Ok(())
115}
116
117pub async fn launch_metrics_server(host: String, port: u16) -> Result<()> {
120 let listener = TcpListener::bind((host, port))
121 .await
122 .context("failed to bind metrics tcp")?;
123
124 let recorder_handle = metrics::setup_metrics_recorder();
125
126 let router = Router::new().route(
127 "/metrics",
128 axum::routing::get(move || std::future::ready(recorder_handle.render())),
129 );
130
131 serve(listener, router.into_make_service())
132 .with_graceful_shutdown(shutdown_signal())
133 .await?;
134
135 Ok(())
136}
137
138async fn make_router<Db: Database>(
139 settings: Settings<<Db as Database>::Settings>,
140) -> Result<Router, eyre::Error> {
141 let db = Db::new(&settings.db_settings)
142 .await
143 .wrap_err_with(|| format!("failed to connect to db: {:?}", settings.db_settings))?;
144 let r = router::router(db, settings);
145 Ok(r)
146}