atuin_daemon/server.rs

use eyre::WrapErr;

use atuin_client::encryption;
use atuin_client::history::store::HistoryStore;
use atuin_client::record::sqlite_store::SqliteStore;
use atuin_client::settings::Settings;
use std::path::PathBuf;
use std::sync::Arc;
use time::OffsetDateTime;
use tracing::{instrument, Level};

use atuin_client::database::{Database, Sqlite as HistoryDatabase};
use atuin_client::history::{History, HistoryId};
use dashmap::DashMap;
use eyre::Result;
use tonic::{transport::Server, Request, Response, Status};

use crate::history::history_server::{History as HistorySvc, HistoryServer};

use crate::history::{EndHistoryReply, EndHistoryRequest, StartHistoryReply, StartHistoryRequest};

mod sync;

#[derive(Debug)]
pub struct HistoryService {
    // A store for WIP history
    // This is history that has not yet been completed, aka a command that's currently running.
    running: Arc<DashMap<HistoryId, History>>,
    store: HistoryStore,
    history_db: HistoryDatabase,
}

impl HistoryService {
    pub fn new(store: HistoryStore, history_db: HistoryDatabase) -> Self {
        Self {
            running: Arc::new(DashMap::new()),
            store,
            history_db,
        }
    }
}

#[tonic::async_trait]
impl HistorySvc for HistoryService {
    #[instrument(skip_all, level = Level::INFO)]
    async fn start_history(
        &self,
        request: Request<StartHistoryRequest>,
    ) -> Result<Response<StartHistoryReply>, Status> {
        let running = self.running.clone();
        let req = request.into_inner();

        let timestamp =
            OffsetDateTime::from_unix_timestamp_nanos(req.timestamp as i128).map_err(|_| {
                Status::invalid_argument(
                    "failed to parse timestamp as unix time (expected nanos since epoch)",
                )
            })?;

        let h: History = History::daemon()
            .timestamp(timestamp)
            .command(req.command)
            .cwd(req.cwd)
            .session(req.session)
            .hostname(req.hostname)
            .build()
            .into();

        // The old behaviour had us inserting half-finished history records into the database.
        // The new behaviour no longer allows that.
        // History that's running is stored in-memory by the daemon, and only committed when
        // complete.
        // If anyone relied on the old behaviour, we could perhaps insert to the history db here
        // too. I'd rather keep it pure, unless that ends up being the case.
        let id = h.id.clone();
        tracing::info!(id = id.to_string(), "start history");
        running.insert(id.clone(), h);

        let reply = StartHistoryReply { id: id.to_string() };

        Ok(Response::new(reply))
    }

    #[instrument(skip_all, level = Level::INFO)]
    async fn end_history(
        &self,
        request: Request<EndHistoryRequest>,
    ) -> Result<Response<EndHistoryReply>, Status> {
        let running = self.running.clone();
        let req = request.into_inner();

        let id = HistoryId(req.id);

        if let Some((_, mut history)) = running.remove(&id) {
            history.exit = req.exit;
            // A duration of zero means the client didn't measure one; fall back to
            // the wall-clock time elapsed since the recorded start.
            history.duration = match req.duration {
                0 => i64::try_from(
                    (OffsetDateTime::now_utc() - history.timestamp).whole_nanoseconds(),
                )
                .expect("failed to convert calculated duration to i64"),
                value => i64::try_from(value).expect("failed to get i64 duration"),
            };

            // Perhaps allow the incremental build to handle this entirely.
            self.history_db
                .save(&history)
                .await
                .map_err(|e| Status::internal(format!("failed to write to db: {e:?}")))?;

            tracing::info!(
                id = id.0.to_string(),
                duration = history.duration,
                "end history"
            );

            let (id, idx) =
                self.store.push(history).await.map_err(|e| {
                    Status::internal(format!("failed to push record to store: {e:?}"))
                })?;

            let reply = EndHistoryReply {
                id: id.0.to_string(),
                idx,
            };

            return Ok(Response::new(reply));
        }

        Err(Status::not_found(format!(
            "could not find history with id: {id}"
        )))
    }
}
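
// A client-side sketch of the lifecycle above, assuming the tonic codegen for
// this proto also produced `history_client::HistoryClient` (the counterpart of
// `HistoryServer`). Channel setup is elided; over the unix socket it would go
// through `Endpoint::connect_with_connector`. The `channel`, `session_id`,
// `hostname`, and `now_nanos` bindings are placeholders.
//
//     use crate::history::history_client::HistoryClient;
//
//     let mut client = HistoryClient::new(channel);
//     let start = client
//         .start_history(StartHistoryRequest {
//             command: "cargo build".to_string(),
//             cwd: "/home/user/project".to_string(),
//             session: session_id,
//             hostname,
//             timestamp: now_nanos,
//         })
//         .await?
//         .into_inner();
//
//     // ... the command runs to completion ...
//
//     client
//         .end_history(EndHistoryRequest {
//             id: start.id,
//             exit: 0,
//             duration: 0, // zero lets the daemon fall back to wall-clock time
//         })
//         .await?;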

#[cfg(unix)]
async fn shutdown_signal(socket: Option<PathBuf>) {
    let mut term = tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())
        .expect("failed to register sigterm handler");
    let mut int = tokio::signal::unix::signal(tokio::signal::unix::SignalKind::interrupt())
        .expect("failed to register sigint handler");

    tokio::select! {
        _ = term.recv() => {},
        _ = int.recv() => {},
    }

    if let Some(socket) = socket {
        eprintln!("Removing socket...");
        std::fs::remove_file(socket).expect("failed to remove socket");
    }
    eprintln!("Shutting down...");
}

#[cfg(windows)]
async fn shutdown_signal() {
    tokio::signal::windows::ctrl_c()
        .expect("failed to register signal handler")
        .recv()
        .await;
    eprintln!("Shutting down...");
}

#[cfg(unix)]
async fn start_server(settings: Settings, history: HistoryService) -> Result<()> {
    use tokio::net::UnixListener;
    use tokio_stream::wrappers::UnixListenerStream;

    let socket_path = settings.daemon.socket_path;

    // Either adopt a pre-bound listener from systemd socket activation, or bind
    // the socket ourselves. The boolean records whether we own the socket file
    // and should remove it on shutdown.
    let (uds, cleanup) = if cfg!(target_os = "linux") && settings.daemon.systemd_socket {
        #[cfg(target_os = "linux")]
        {
            use eyre::OptionExt;
            tracing::info!("getting systemd socket");
            let listener = listenfd::ListenFd::from_env()
                .take_unix_listener(0)?
                .ok_or_eyre("missing systemd socket")?;
            listener.set_nonblocking(true)?;
            let actual_path = listener
                .local_addr()
                .context("getting systemd socket's path")
                .and_then(|addr| {
                    addr.as_pathname()
                        .ok_or_eyre("systemd socket missing path")
                        .map(|path| path.to_owned())
                });
            match actual_path {
                Ok(actual_path) => {
                    tracing::info!("listening on systemd socket: {actual_path:?}");
                    if actual_path != std::path::Path::new(&socket_path) {
                        tracing::warn!(
                            "systemd socket is not at configured client path: {socket_path:?}"
                        );
                    }
                }
                Err(err) => {
                    tracing::warn!("could not detect systemd socket path, ensure that it's at the configured path: {socket_path:?}, error: {err:?}");
                }
            }
            (UnixListener::from_std(listener)?, false)
        }
        #[cfg(not(target_os = "linux"))]
        unreachable!()
    } else {
        tracing::info!("listening on unix socket {socket_path:?}");
        (UnixListener::bind(socket_path.clone())?, true)
    };

    let uds_stream = UnixListenerStream::new(uds);

    Server::builder()
        .add_service(HistoryServer::new(history))
        .serve_with_incoming_shutdown(
            uds_stream,
            shutdown_signal(cleanup.then_some(socket_path.into())),
        )
        .await?;

    Ok(())
}
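
// For the systemd-activation branch above, a socket unit along these lines is
// what hands the daemon a pre-bound listener via $LISTEN_FDS (a sketch; the
// unit name and socket path are illustrative, not defined by this file):
//
//     # atuin.socket
//     [Socket]
//     ListenStream=%t/atuin.sock
//
//     [Install]
//     WantedBy=sockets.target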

#[cfg(not(unix))]
async fn start_server(settings: Settings, history: HistoryService) -> Result<()> {
    use tokio::net::TcpListener;
    use tokio_stream::wrappers::TcpListenerStream;

    let port = settings.daemon.tcp_port;
    let url = format!("127.0.0.1:{}", port);
    let tcp = TcpListener::bind(url).await?;
    let tcp_stream = TcpListenerStream::new(tcp);

    tracing::info!("listening on tcp port {:?}", port);

    Server::builder()
        .add_service(HistoryServer::new(history))
        .serve_with_incoming_shutdown(tcp_stream, shutdown_signal())
        .await?;
    Ok(())
}
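
// Over TCP the generated client can dial this address directly, assuming the
// codegen ran with tonic's transport feature enabled (which emits a `connect`
// constructor on the client):
//
//     let mut client = crate::history::history_client::HistoryClient::connect(
//         format!("http://127.0.0.1:{port}"),
//     )
//     .await?;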

// break the above down when we end up with multiple services

/// Listen for history events from client shells and persist completed entries.
/// On unix this serves on the configured unix socket; elsewhere it binds a
/// local TCP port.
pub async fn listen(
    settings: Settings,
    store: SqliteStore,
    history_db: HistoryDatabase,
) -> Result<()> {
    let encryption_key: [u8; 32] = encryption::load_key(&settings)
        .context("could not load encryption key")?
        .into();

    let host_id = Settings::host_id().expect("failed to get host_id");
    let history_store = HistoryStore::new(store.clone(), host_id, encryption_key);

    let history = HistoryService::new(history_store.clone(), history_db.clone());

    // start services
    tokio::spawn(sync::worker(
        settings.clone(),
        store,
        history_store,
        history_db,
    ));

    start_server(settings, history).await
}
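
// A minimal sketch of wiring this together from a daemon entrypoint. The
// constructor names and arguments are assumptions for illustration, not this
// crate's exact API:
//
//     let settings = Settings::new()?;                        // hypothetical loader
//     let store = SqliteStore::new(db_path).await?;           // hypothetical args
//     let history_db = HistoryDatabase::new(db_path).await?;  // hypothetical args
//     listen(settings, store, history_db).await?;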