1use std::collections::HashMap;
2use std::env;
3use std::time::Duration;
4
5use eyre::{bail, Result};
6use reqwest::{
7 header::{HeaderMap, AUTHORIZATION, USER_AGENT},
8 Response, StatusCode, Url,
9};
10
11use atuin_common::{
12 api::{
13 AddHistoryRequest, ChangePasswordRequest, CountResponse, DeleteHistoryRequest,
14 ErrorResponse, LoginRequest, LoginResponse, MeResponse, RegisterResponse,
15 SendVerificationResponse, StatusResponse, SyncHistoryResponse, VerificationTokenRequest,
16 VerificationTokenResponse,
17 },
18 record::RecordStatus,
19};
20use atuin_common::{
21 api::{ATUIN_CARGO_VERSION, ATUIN_HEADER_VERSION, ATUIN_VERSION},
22 record::{EncryptedData, HostId, Record, RecordIdx},
23};
24
25use semver::Version;
26use time::format_description::well_known::Rfc3339;
27use time::OffsetDateTime;
28
29use crate::{history::History, sync::hash_str, utils::get_host_user};
30
31static APP_USER_AGENT: &str = concat!("atuin/", env!("CARGO_PKG_VERSION"),);
32
33pub struct Client<'a> {
34 sync_addr: &'a str,
35 client: reqwest::Client,
36}
37
38pub async fn register(
39 address: &str,
40 username: &str,
41 email: &str,
42 password: &str,
43) -> Result<RegisterResponse> {
44 let mut map = HashMap::new();
45 map.insert("username", username);
46 map.insert("email", email);
47 map.insert("password", password);
48
49 let url = format!("{address}/user/{username}");
50 let resp = reqwest::get(url).await?;
51
52 if resp.status().is_success() {
53 bail!("username already in use");
54 }
55
56 let url = format!("{address}/register");
57 let client = reqwest::Client::new();
58 let resp = client
59 .post(url)
60 .header(USER_AGENT, APP_USER_AGENT)
61 .header(ATUIN_HEADER_VERSION, ATUIN_CARGO_VERSION)
62 .json(&map)
63 .send()
64 .await?;
65 let resp = handle_resp_error(resp).await?;
66
67 if !ensure_version(&resp)? {
68 bail!("could not register user due to version mismatch");
69 }
70
71 let session = resp.json::<RegisterResponse>().await?;
72 Ok(session)
73}
74
75pub async fn login(address: &str, req: LoginRequest) -> Result<LoginResponse> {
76 let url = format!("{address}/login");
77 let client = reqwest::Client::new();
78
79 let resp = client
80 .post(url)
81 .header(USER_AGENT, APP_USER_AGENT)
82 .json(&req)
83 .send()
84 .await?;
85 let resp = handle_resp_error(resp).await?;
86
87 if !ensure_version(&resp)? {
88 bail!("Could not login due to version mismatch");
89 }
90
91 let session = resp.json::<LoginResponse>().await?;
92 Ok(session)
93}
94
/// Fetches the version advertised by `https://api.atuin.sh`.
///
/// The index document is decoded as an `IndexResponse` and its `version`
/// field is parsed as a semantic version.
///
/// # Errors
///
/// Fails if the request cannot be sent, the server returns a non-success
/// status, the body cannot be decoded, or the version string is not valid
/// semver.
#[cfg(feature = "check-update")]
pub async fn latest_version() -> Result<Version> {
    use atuin_common::api::IndexResponse;

    let raw = reqwest::Client::new()
        .get("https://api.atuin.sh")
        .header(USER_AGENT, APP_USER_AGENT)
        .send()
        .await?;
    let checked = handle_resp_error(raw).await?;

    let index = checked.json::<IndexResponse>().await?;

    Ok(Version::parse(index.version.as_str())?)
}
114
115pub fn ensure_version(response: &Response) -> Result<bool> {
116 let version = response.headers().get(ATUIN_HEADER_VERSION);
117
118 let version = if let Some(version) = version {
119 match version.to_str() {
120 Ok(v) => Version::parse(v),
121 Err(e) => bail!("failed to parse server version: {:?}", e),
122 }
123 } else {
124 bail!("Server not reporting its version: it is either too old or unhealthy");
125 }?;
126
127 if version.major < ATUIN_VERSION.major {
129 println!("Atuin version mismatch! In order to successfully sync, the server needs to run a newer version of Atuin");
130 println!("Client: {}", ATUIN_CARGO_VERSION);
131 println!("Server: {}", version);
132
133 return Ok(false);
134 }
135
136 Ok(true)
137}
138
139async fn handle_resp_error(resp: Response) -> Result<Response> {
140 let status = resp.status();
141
142 if status == StatusCode::SERVICE_UNAVAILABLE {
143 bail!(
144 "Service unavailable: check https://status.atuin.sh (or get in touch with your host)"
145 );
146 }
147
148 if status == StatusCode::TOO_MANY_REQUESTS {
149 bail!("Rate limited; please wait before doing that again");
150 }
151
152 if !status.is_success() {
153 if let Ok(error) = resp.json::<ErrorResponse>().await {
154 let reason = error.reason;
155
156 if status.is_client_error() {
157 bail!("Invalid request to the service: {status} - {reason}.")
158 }
159
160 bail!("There was an error with the atuin sync service, server error {status}: {reason}.\nIf the problem persists, contact the host")
161 }
162
163 bail!("There was an error with the atuin sync service: Status {status:?}.\nIf the problem persists, contact the host")
164 }
165
166 Ok(resp)
167}
168
169impl<'a> Client<'a> {
170 pub fn new(
171 sync_addr: &'a str,
172 session_token: &str,
173 connect_timeout: u64,
174 timeout: u64,
175 ) -> Result<Self> {
176 let mut headers = HeaderMap::new();
177 headers.insert(AUTHORIZATION, format!("Token {session_token}").parse()?);
178
179 headers.insert(ATUIN_HEADER_VERSION, ATUIN_CARGO_VERSION.parse()?);
181
182 Ok(Client {
183 sync_addr,
184 client: reqwest::Client::builder()
185 .user_agent(APP_USER_AGENT)
186 .default_headers(headers)
187 .connect_timeout(Duration::new(connect_timeout, 0))
188 .timeout(Duration::new(timeout, 0))
189 .build()?,
190 })
191 }
192
193 pub async fn count(&self) -> Result<i64> {
194 let url = format!("{}/sync/count", self.sync_addr);
195 let url = Url::parse(url.as_str())?;
196
197 let resp = self.client.get(url).send().await?;
198 let resp = handle_resp_error(resp).await?;
199
200 if !ensure_version(&resp)? {
201 bail!("could not sync due to version mismatch");
202 }
203
204 if resp.status() != StatusCode::OK {
205 bail!("failed to get count (are you logged in?)");
206 }
207
208 let count = resp.json::<CountResponse>().await?;
209
210 Ok(count.count)
211 }
212
213 pub async fn status(&self) -> Result<StatusResponse> {
214 let url = format!("{}/sync/status", self.sync_addr);
215 let url = Url::parse(url.as_str())?;
216
217 let resp = self.client.get(url).send().await?;
218 let resp = handle_resp_error(resp).await?;
219
220 if !ensure_version(&resp)? {
221 bail!("could not sync due to version mismatch");
222 }
223
224 let status = resp.json::<StatusResponse>().await?;
225
226 Ok(status)
227 }
228
229 pub async fn me(&self) -> Result<MeResponse> {
230 let url = format!("{}/api/v0/me", self.sync_addr);
231 let url = Url::parse(url.as_str())?;
232
233 let resp = self.client.get(url).send().await?;
234 let resp = handle_resp_error(resp).await?;
235
236 let status = resp.json::<MeResponse>().await?;
237
238 Ok(status)
239 }
240
241 pub async fn get_history(
242 &self,
243 sync_ts: OffsetDateTime,
244 history_ts: OffsetDateTime,
245 host: Option<String>,
246 ) -> Result<SyncHistoryResponse> {
247 let host = host.unwrap_or_else(|| hash_str(&get_host_user()));
248
249 let url = format!(
250 "{}/sync/history?sync_ts={}&history_ts={}&host={}",
251 self.sync_addr,
252 urlencoding::encode(sync_ts.format(&Rfc3339)?.as_str()),
253 urlencoding::encode(history_ts.format(&Rfc3339)?.as_str()),
254 host,
255 );
256
257 let resp = self.client.get(url).send().await?;
258 let resp = handle_resp_error(resp).await?;
259
260 let history = resp.json::<SyncHistoryResponse>().await?;
261 Ok(history)
262 }
263
264 pub async fn post_history(&self, history: &[AddHistoryRequest]) -> Result<()> {
265 let url = format!("{}/history", self.sync_addr);
266 let url = Url::parse(url.as_str())?;
267
268 let resp = self.client.post(url).json(history).send().await?;
269 handle_resp_error(resp).await?;
270
271 Ok(())
272 }
273
274 pub async fn delete_history(&self, h: History) -> Result<()> {
275 let url = format!("{}/history", self.sync_addr);
276 let url = Url::parse(url.as_str())?;
277
278 let resp = self
279 .client
280 .delete(url)
281 .json(&DeleteHistoryRequest {
282 client_id: h.id.to_string(),
283 })
284 .send()
285 .await?;
286
287 handle_resp_error(resp).await?;
288
289 Ok(())
290 }
291
292 pub async fn delete_store(&self) -> Result<()> {
293 let url = format!("{}/api/v0/store", self.sync_addr);
294 let url = Url::parse(url.as_str())?;
295
296 let resp = self.client.delete(url).send().await?;
297
298 handle_resp_error(resp).await?;
299
300 Ok(())
301 }
302
303 pub async fn post_records(&self, records: &[Record<EncryptedData>]) -> Result<()> {
304 let url = format!("{}/api/v0/record", self.sync_addr);
305 let url = Url::parse(url.as_str())?;
306
307 debug!("uploading {} records to {url}", records.len());
308
309 let resp = self.client.post(url).json(records).send().await?;
310 handle_resp_error(resp).await?;
311
312 Ok(())
313 }
314
315 pub async fn next_records(
316 &self,
317 host: HostId,
318 tag: String,
319 start: RecordIdx,
320 count: u64,
321 ) -> Result<Vec<Record<EncryptedData>>> {
322 debug!(
323 "fetching record/s from host {}/{}/{}",
324 host.0.to_string(),
325 tag,
326 start
327 );
328
329 let url = format!(
330 "{}/api/v0/record/next?host={}&tag={}&count={}&start={}",
331 self.sync_addr, host.0, tag, count, start
332 );
333
334 let url = Url::parse(url.as_str())?;
335
336 let resp = self.client.get(url).send().await?;
337 let resp = handle_resp_error(resp).await?;
338
339 let records = resp.json::<Vec<Record<EncryptedData>>>().await?;
340
341 Ok(records)
342 }
343
344 pub async fn record_status(&self) -> Result<RecordStatus> {
345 let url = format!("{}/api/v0/record", self.sync_addr);
346 let url = Url::parse(url.as_str())?;
347
348 let resp = self.client.get(url).send().await?;
349 let resp = handle_resp_error(resp).await?;
350
351 if !ensure_version(&resp)? {
352 bail!("could not sync records due to version mismatch");
353 }
354
355 let index = resp.json().await?;
356
357 debug!("got remote index {:?}", index);
358
359 Ok(index)
360 }
361
362 pub async fn delete(&self) -> Result<()> {
363 let url = format!("{}/account", self.sync_addr);
364 let url = Url::parse(url.as_str())?;
365
366 let resp = self.client.delete(url).send().await?;
367
368 if resp.status() == 403 {
369 bail!("invalid login details");
370 } else if resp.status() == 200 {
371 Ok(())
372 } else {
373 bail!("Unknown error");
374 }
375 }
376
377 pub async fn change_password(
378 &self,
379 current_password: String,
380 new_password: String,
381 ) -> Result<()> {
382 let url = format!("{}/account/password", self.sync_addr);
383 let url = Url::parse(url.as_str())?;
384
385 let resp = self
386 .client
387 .patch(url)
388 .json(&ChangePasswordRequest {
389 current_password,
390 new_password,
391 })
392 .send()
393 .await?;
394
395 if resp.status() == 401 {
396 bail!("current password is incorrect")
397 } else if resp.status() == 403 {
398 bail!("invalid login details");
399 } else if resp.status() == 200 {
400 Ok(())
401 } else {
402 bail!("Unknown error");
403 }
404 }
405
406 pub async fn verify(&self, token: Option<String>) -> Result<(bool, bool)> {
408 let (email_sent, verified) = if let Some(token) = token {
410 let url = format!("{}/api/v0/account/verify", self.sync_addr);
411 let url = Url::parse(url.as_str())?;
412
413 let resp = self
414 .client
415 .post(url)
416 .json(&VerificationTokenRequest { token })
417 .send()
418 .await?;
419 let resp = handle_resp_error(resp).await?;
420 let resp = resp.json::<VerificationTokenResponse>().await?;
421
422 (false, resp.verified)
423 } else {
424 let url = format!("{}/api/v0/account/send-verification", self.sync_addr);
425 let url = Url::parse(url.as_str())?;
426
427 let resp = self.client.post(url).send().await?;
428 let resp = handle_resp_error(resp).await?;
429 let resp = resp.json::<SendVerificationResponse>().await?;
430
431 (resp.email_sent, resp.verified)
432 };
433
434 Ok((email_sent, verified))
435 }
436}