1use eyre::WrapErr;
2
3use atuin_client::encryption;
4use atuin_client::history::store::HistoryStore;
5use atuin_client::record::sqlite_store::SqliteStore;
6use atuin_client::settings::Settings;
7use std::path::PathBuf;
8use std::sync::Arc;
9use time::OffsetDateTime;
10use tracing::{instrument, Level};
11
12use atuin_client::database::{Database, Sqlite as HistoryDatabase};
13use atuin_client::history::{History, HistoryId};
14use dashmap::DashMap;
15use eyre::Result;
16use tonic::{transport::Server, Request, Response, Status};
17
18use crate::history::history_server::{History as HistorySvc, HistoryServer};
19
20use crate::history::{EndHistoryReply, EndHistoryRequest, StartHistoryReply, StartHistoryRequest};
21
22mod sync;
23
24#[derive(Debug)]
25pub struct HistoryService {
26 running: Arc<DashMap<HistoryId, History>>,
29 store: HistoryStore,
30 history_db: HistoryDatabase,
31}
32
33impl HistoryService {
34 pub fn new(store: HistoryStore, history_db: HistoryDatabase) -> Self {
35 Self {
36 running: Arc::new(DashMap::new()),
37 store,
38 history_db,
39 }
40 }
41}
42
43#[tonic::async_trait()]
44impl HistorySvc for HistoryService {
45 #[instrument(skip_all, level = Level::INFO)]
46 async fn start_history(
47 &self,
48 request: Request<StartHistoryRequest>,
49 ) -> Result<Response<StartHistoryReply>, Status> {
50 let running = self.running.clone();
51 let req = request.into_inner();
52
53 let timestamp =
54 OffsetDateTime::from_unix_timestamp_nanos(req.timestamp as i128).map_err(|_| {
55 Status::invalid_argument(
56 "failed to parse timestamp as unix time (expected nanos since epoch)",
57 )
58 })?;
59
60 let h: History = History::daemon()
61 .timestamp(timestamp)
62 .command(req.command)
63 .cwd(req.cwd)
64 .session(req.session)
65 .hostname(req.hostname)
66 .build()
67 .into();
68
69 let id = h.id.clone();
76 tracing::info!(id = id.to_string(), "start history");
77 running.insert(id.clone(), h);
78
79 let reply = StartHistoryReply { id: id.to_string() };
80
81 Ok(Response::new(reply))
82 }
83
84 #[instrument(skip_all, level = Level::INFO)]
85 async fn end_history(
86 &self,
87 request: Request<EndHistoryRequest>,
88 ) -> Result<Response<EndHistoryReply>, Status> {
89 let running = self.running.clone();
90 let req = request.into_inner();
91
92 let id = HistoryId(req.id);
93
94 if let Some((_, mut history)) = running.remove(&id) {
95 history.exit = req.exit;
96 history.duration = match req.duration {
97 0 => i64::try_from(
98 (OffsetDateTime::now_utc() - history.timestamp).whole_nanoseconds(),
99 )
100 .expect("failed to convert calculated duration to i64"),
101 value => i64::try_from(value).expect("failed to get i64 duration"),
102 };
103
104 self.history_db
106 .save(&history)
107 .await
108 .map_err(|e| Status::internal(format!("failed to write to db: {e:?}")))?;
109
110 tracing::info!(
111 id = id.0.to_string(),
112 duration = history.duration,
113 "end history"
114 );
115
116 let (id, idx) =
117 self.store.push(history).await.map_err(|e| {
118 Status::internal(format!("failed to push record to store: {e:?}"))
119 })?;
120
121 let reply = EndHistoryReply {
122 id: id.0.to_string(),
123 idx,
124 };
125
126 return Ok(Response::new(reply));
127 }
128
129 Err(Status::not_found(format!(
130 "could not find history with id: {id}"
131 )))
132 }
133}
134
135#[cfg(unix)]
136async fn shutdown_signal(socket: Option<PathBuf>) {
137 let mut term = tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())
138 .expect("failed to register sigterm handler");
139 let mut int = tokio::signal::unix::signal(tokio::signal::unix::SignalKind::interrupt())
140 .expect("failed to register sigint handler");
141
142 tokio::select! {
143 _ = term.recv() => {},
144 _ = int.recv() => {},
145 }
146
147 eprintln!("Removing socket...");
148 if let Some(socket) = socket {
149 std::fs::remove_file(socket).expect("failed to remove socket");
150 }
151 eprintln!("Shutting down...");
152}
153
/// Resolves once Ctrl-C is received, used as the server's graceful-shutdown
/// trigger on Windows.
///
/// # Panics
/// Panics if the Ctrl-C handler cannot be registered.
#[cfg(windows)]
async fn shutdown_signal() {
    let mut ctrl_c = tokio::signal::windows::ctrl_c().expect("failed to register signal handler");
    ctrl_c.recv().await;

    eprintln!("Shutting down...");
}
162
163#[cfg(unix)]
164async fn start_server(settings: Settings, history: HistoryService) -> Result<()> {
165 use tokio::net::UnixListener;
166 use tokio_stream::wrappers::UnixListenerStream;
167
168 let socket_path = settings.daemon.socket_path;
169
170 let (uds, cleanup) = if cfg!(target_os = "linux") && settings.daemon.systemd_socket {
171 #[cfg(target_os = "linux")]
172 {
173 use eyre::OptionExt;
174 tracing::info!("getting systemd socket");
175 let listener = listenfd::ListenFd::from_env()
176 .take_unix_listener(0)?
177 .ok_or_eyre("missing systemd socket")?;
178 listener.set_nonblocking(true)?;
179 let actual_path = listener
180 .local_addr()
181 .context("getting systemd socket's path")
182 .and_then(|addr| {
183 addr.as_pathname()
184 .ok_or_eyre("systemd socket missing path")
185 .map(|path| path.to_owned())
186 });
187 match actual_path {
188 Ok(actual_path) => {
189 tracing::info!("listening on systemd socket: {actual_path:?}");
190 if actual_path != std::path::Path::new(&socket_path) {
191 tracing::warn!(
192 "systemd socket is not at configured client path: {socket_path:?}"
193 );
194 }
195 }
196 Err(err) => {
197 tracing::warn!("could not detect systemd socket path, ensure that it's at the configured path: {socket_path:?}, error: {err:?}");
198 }
199 }
200 (UnixListener::from_std(listener)?, false)
201 }
202 #[cfg(not(target_os = "linux"))]
203 unreachable!()
204 } else {
205 tracing::info!("listening on unix socket {socket_path:?}");
206 (UnixListener::bind(socket_path.clone())?, true)
207 };
208
209 let uds_stream = UnixListenerStream::new(uds);
210
211 Server::builder()
212 .add_service(HistoryServer::new(history))
213 .serve_with_incoming_shutdown(
214 uds_stream,
215 shutdown_signal(cleanup.then_some(socket_path.into())),
216 )
217 .await?;
218
219 Ok(())
220}
221
/// Serves the history gRPC service on `127.0.0.1:<settings.daemon.tcp_port>`
/// until `shutdown_signal` resolves. Used on platforms without Unix sockets.
///
/// # Errors
/// Fails if the TCP port cannot be bound or the server errors.
#[cfg(not(unix))]
async fn start_server(settings: Settings, history: HistoryService) -> Result<()> {
    use tokio::net::TcpListener;
    use tokio_stream::wrappers::TcpListenerStream;

    let port = settings.daemon.tcp_port;
    let listener = TcpListener::bind(format!("127.0.0.1:{port}")).await?;
    let incoming = TcpListenerStream::new(listener);

    tracing::info!("listening on tcp port {:?}", port);

    let service = HistoryServer::new(history);
    Server::builder()
        .add_service(service)
        .serve_with_incoming_shutdown(incoming, shutdown_signal())
        .await?;

    Ok(())
}
240
241pub async fn listen(
246 settings: Settings,
247 store: SqliteStore,
248 history_db: HistoryDatabase,
249) -> Result<()> {
250 let encryption_key: [u8; 32] = encryption::load_key(&settings)
251 .context("could not load encryption key")?
252 .into();
253
254 let host_id = Settings::host_id().expect("failed to get host_id");
255 let history_store = HistoryStore::new(store.clone(), host_id, encryption_key);
256
257 let history = HistoryService::new(history_store.clone(), history_db.clone());
258
259 tokio::spawn(sync::worker(
261 settings.clone(),
262 store,
263 history_store,
264 history_db,
265 ));
266
267 start_server(settings, history).await
268}