atuin_client/
api_client.rs

1use std::collections::HashMap;
2use std::env;
3use std::time::Duration;
4
5use eyre::{bail, Result};
6use reqwest::{
7    header::{HeaderMap, AUTHORIZATION, USER_AGENT},
8    Response, StatusCode, Url,
9};
10
11use atuin_common::{
12    api::{
13        AddHistoryRequest, ChangePasswordRequest, CountResponse, DeleteHistoryRequest,
14        ErrorResponse, LoginRequest, LoginResponse, MeResponse, RegisterResponse,
15        SendVerificationResponse, StatusResponse, SyncHistoryResponse, VerificationTokenRequest,
16        VerificationTokenResponse,
17    },
18    record::RecordStatus,
19};
20use atuin_common::{
21    api::{ATUIN_CARGO_VERSION, ATUIN_HEADER_VERSION, ATUIN_VERSION},
22    record::{EncryptedData, HostId, Record, RecordIdx},
23};
24
25use semver::Version;
26use time::format_description::well_known::Rfc3339;
27use time::OffsetDateTime;
28
29use crate::{history::History, sync::hash_str, utils::get_host_user};
30
31static APP_USER_AGENT: &str = concat!("atuin/", env!("CARGO_PKG_VERSION"),);
32
33pub struct Client<'a> {
34    sync_addr: &'a str,
35    client: reqwest::Client,
36}
37
38pub async fn register(
39    address: &str,
40    username: &str,
41    email: &str,
42    password: &str,
43) -> Result<RegisterResponse> {
44    let mut map = HashMap::new();
45    map.insert("username", username);
46    map.insert("email", email);
47    map.insert("password", password);
48
49    let url = format!("{address}/user/{username}");
50    let resp = reqwest::get(url).await?;
51
52    if resp.status().is_success() {
53        bail!("username already in use");
54    }
55
56    let url = format!("{address}/register");
57    let client = reqwest::Client::new();
58    let resp = client
59        .post(url)
60        .header(USER_AGENT, APP_USER_AGENT)
61        .header(ATUIN_HEADER_VERSION, ATUIN_CARGO_VERSION)
62        .json(&map)
63        .send()
64        .await?;
65    let resp = handle_resp_error(resp).await?;
66
67    if !ensure_version(&resp)? {
68        bail!("could not register user due to version mismatch");
69    }
70
71    let session = resp.json::<RegisterResponse>().await?;
72    Ok(session)
73}
74
75pub async fn login(address: &str, req: LoginRequest) -> Result<LoginResponse> {
76    let url = format!("{address}/login");
77    let client = reqwest::Client::new();
78
79    let resp = client
80        .post(url)
81        .header(USER_AGENT, APP_USER_AGENT)
82        .json(&req)
83        .send()
84        .await?;
85    let resp = handle_resp_error(resp).await?;
86
87    if !ensure_version(&resp)? {
88        bail!("Could not login due to version mismatch");
89    }
90
91    let session = resp.json::<LoginResponse>().await?;
92    Ok(session)
93}
94
/// Asks the public atuin API (`https://api.atuin.sh`) which release is the latest.
///
/// # Errors
/// Fails on transport errors, on error responses (see `handle_resp_error`), if the body is not
/// an [`IndexResponse`], or if its `version` field is not valid semver.
#[cfg(feature = "check-update")]
pub async fn latest_version() -> Result<Version> {
    use atuin_common::api::IndexResponse;

    const INDEX_URL: &str = "https://api.atuin.sh";

    let raw = reqwest::Client::new()
        .get(INDEX_URL)
        .header(USER_AGENT, APP_USER_AGENT)
        .send()
        .await?;
    let checked = handle_resp_error(raw).await?;

    let body: IndexResponse = checked.json().await?;

    Ok(Version::parse(body.version.as_str())?)
}
114
115pub fn ensure_version(response: &Response) -> Result<bool> {
116    let version = response.headers().get(ATUIN_HEADER_VERSION);
117
118    let version = if let Some(version) = version {
119        match version.to_str() {
120            Ok(v) => Version::parse(v),
121            Err(e) => bail!("failed to parse server version: {:?}", e),
122        }
123    } else {
124        bail!("Server not reporting its version: it is either too old or unhealthy");
125    }?;
126
127    // If the client is newer than the server
128    if version.major < ATUIN_VERSION.major {
129        println!("Atuin version mismatch! In order to successfully sync, the server needs to run a newer version of Atuin");
130        println!("Client: {}", ATUIN_CARGO_VERSION);
131        println!("Server: {}", version);
132
133        return Ok(false);
134    }
135
136    Ok(true)
137}
138
139async fn handle_resp_error(resp: Response) -> Result<Response> {
140    let status = resp.status();
141
142    if status == StatusCode::SERVICE_UNAVAILABLE {
143        bail!(
144            "Service unavailable: check https://status.atuin.sh (or get in touch with your host)"
145        );
146    }
147
148    if status == StatusCode::TOO_MANY_REQUESTS {
149        bail!("Rate limited; please wait before doing that again");
150    }
151
152    if !status.is_success() {
153        if let Ok(error) = resp.json::<ErrorResponse>().await {
154            let reason = error.reason;
155
156            if status.is_client_error() {
157                bail!("Invalid request to the service: {status} - {reason}.")
158            }
159
160            bail!("There was an error with the atuin sync service, server error {status}: {reason}.\nIf the problem persists, contact the host")
161        }
162
163        bail!("There was an error with the atuin sync service: Status {status:?}.\nIf the problem persists, contact the host")
164    }
165
166    Ok(resp)
167}
168
169impl<'a> Client<'a> {
170    pub fn new(
171        sync_addr: &'a str,
172        session_token: &str,
173        connect_timeout: u64,
174        timeout: u64,
175    ) -> Result<Self> {
176        let mut headers = HeaderMap::new();
177        headers.insert(AUTHORIZATION, format!("Token {session_token}").parse()?);
178
179        // used for semver server check
180        headers.insert(ATUIN_HEADER_VERSION, ATUIN_CARGO_VERSION.parse()?);
181
182        Ok(Client {
183            sync_addr,
184            client: reqwest::Client::builder()
185                .user_agent(APP_USER_AGENT)
186                .default_headers(headers)
187                .connect_timeout(Duration::new(connect_timeout, 0))
188                .timeout(Duration::new(timeout, 0))
189                .build()?,
190        })
191    }
192
193    pub async fn count(&self) -> Result<i64> {
194        let url = format!("{}/sync/count", self.sync_addr);
195        let url = Url::parse(url.as_str())?;
196
197        let resp = self.client.get(url).send().await?;
198        let resp = handle_resp_error(resp).await?;
199
200        if !ensure_version(&resp)? {
201            bail!("could not sync due to version mismatch");
202        }
203
204        if resp.status() != StatusCode::OK {
205            bail!("failed to get count (are you logged in?)");
206        }
207
208        let count = resp.json::<CountResponse>().await?;
209
210        Ok(count.count)
211    }
212
213    pub async fn status(&self) -> Result<StatusResponse> {
214        let url = format!("{}/sync/status", self.sync_addr);
215        let url = Url::parse(url.as_str())?;
216
217        let resp = self.client.get(url).send().await?;
218        let resp = handle_resp_error(resp).await?;
219
220        if !ensure_version(&resp)? {
221            bail!("could not sync due to version mismatch");
222        }
223
224        let status = resp.json::<StatusResponse>().await?;
225
226        Ok(status)
227    }
228
229    pub async fn me(&self) -> Result<MeResponse> {
230        let url = format!("{}/api/v0/me", self.sync_addr);
231        let url = Url::parse(url.as_str())?;
232
233        let resp = self.client.get(url).send().await?;
234        let resp = handle_resp_error(resp).await?;
235
236        let status = resp.json::<MeResponse>().await?;
237
238        Ok(status)
239    }
240
241    pub async fn get_history(
242        &self,
243        sync_ts: OffsetDateTime,
244        history_ts: OffsetDateTime,
245        host: Option<String>,
246    ) -> Result<SyncHistoryResponse> {
247        let host = host.unwrap_or_else(|| hash_str(&get_host_user()));
248
249        let url = format!(
250            "{}/sync/history?sync_ts={}&history_ts={}&host={}",
251            self.sync_addr,
252            urlencoding::encode(sync_ts.format(&Rfc3339)?.as_str()),
253            urlencoding::encode(history_ts.format(&Rfc3339)?.as_str()),
254            host,
255        );
256
257        let resp = self.client.get(url).send().await?;
258        let resp = handle_resp_error(resp).await?;
259
260        let history = resp.json::<SyncHistoryResponse>().await?;
261        Ok(history)
262    }
263
264    pub async fn post_history(&self, history: &[AddHistoryRequest]) -> Result<()> {
265        let url = format!("{}/history", self.sync_addr);
266        let url = Url::parse(url.as_str())?;
267
268        let resp = self.client.post(url).json(history).send().await?;
269        handle_resp_error(resp).await?;
270
271        Ok(())
272    }
273
274    pub async fn delete_history(&self, h: History) -> Result<()> {
275        let url = format!("{}/history", self.sync_addr);
276        let url = Url::parse(url.as_str())?;
277
278        let resp = self
279            .client
280            .delete(url)
281            .json(&DeleteHistoryRequest {
282                client_id: h.id.to_string(),
283            })
284            .send()
285            .await?;
286
287        handle_resp_error(resp).await?;
288
289        Ok(())
290    }
291
292    pub async fn delete_store(&self) -> Result<()> {
293        let url = format!("{}/api/v0/store", self.sync_addr);
294        let url = Url::parse(url.as_str())?;
295
296        let resp = self.client.delete(url).send().await?;
297
298        handle_resp_error(resp).await?;
299
300        Ok(())
301    }
302
303    pub async fn post_records(&self, records: &[Record<EncryptedData>]) -> Result<()> {
304        let url = format!("{}/api/v0/record", self.sync_addr);
305        let url = Url::parse(url.as_str())?;
306
307        debug!("uploading {} records to {url}", records.len());
308
309        let resp = self.client.post(url).json(records).send().await?;
310        handle_resp_error(resp).await?;
311
312        Ok(())
313    }
314
315    pub async fn next_records(
316        &self,
317        host: HostId,
318        tag: String,
319        start: RecordIdx,
320        count: u64,
321    ) -> Result<Vec<Record<EncryptedData>>> {
322        debug!(
323            "fetching record/s from host {}/{}/{}",
324            host.0.to_string(),
325            tag,
326            start
327        );
328
329        let url = format!(
330            "{}/api/v0/record/next?host={}&tag={}&count={}&start={}",
331            self.sync_addr, host.0, tag, count, start
332        );
333
334        let url = Url::parse(url.as_str())?;
335
336        let resp = self.client.get(url).send().await?;
337        let resp = handle_resp_error(resp).await?;
338
339        let records = resp.json::<Vec<Record<EncryptedData>>>().await?;
340
341        Ok(records)
342    }
343
344    pub async fn record_status(&self) -> Result<RecordStatus> {
345        let url = format!("{}/api/v0/record", self.sync_addr);
346        let url = Url::parse(url.as_str())?;
347
348        let resp = self.client.get(url).send().await?;
349        let resp = handle_resp_error(resp).await?;
350
351        if !ensure_version(&resp)? {
352            bail!("could not sync records due to version mismatch");
353        }
354
355        let index = resp.json().await?;
356
357        debug!("got remote index {:?}", index);
358
359        Ok(index)
360    }
361
362    pub async fn delete(&self) -> Result<()> {
363        let url = format!("{}/account", self.sync_addr);
364        let url = Url::parse(url.as_str())?;
365
366        let resp = self.client.delete(url).send().await?;
367
368        if resp.status() == 403 {
369            bail!("invalid login details");
370        } else if resp.status() == 200 {
371            Ok(())
372        } else {
373            bail!("Unknown error");
374        }
375    }
376
377    pub async fn change_password(
378        &self,
379        current_password: String,
380        new_password: String,
381    ) -> Result<()> {
382        let url = format!("{}/account/password", self.sync_addr);
383        let url = Url::parse(url.as_str())?;
384
385        let resp = self
386            .client
387            .patch(url)
388            .json(&ChangePasswordRequest {
389                current_password,
390                new_password,
391            })
392            .send()
393            .await?;
394
395        if resp.status() == 401 {
396            bail!("current password is incorrect")
397        } else if resp.status() == 403 {
398            bail!("invalid login details");
399        } else if resp.status() == 200 {
400            Ok(())
401        } else {
402            bail!("Unknown error");
403        }
404    }
405
406    // Either request a verification email if token is null, or validate a token
407    pub async fn verify(&self, token: Option<String>) -> Result<(bool, bool)> {
408        // could dedupe this a bit, but it's simple at the moment
409        let (email_sent, verified) = if let Some(token) = token {
410            let url = format!("{}/api/v0/account/verify", self.sync_addr);
411            let url = Url::parse(url.as_str())?;
412
413            let resp = self
414                .client
415                .post(url)
416                .json(&VerificationTokenRequest { token })
417                .send()
418                .await?;
419            let resp = handle_resp_error(resp).await?;
420            let resp = resp.json::<VerificationTokenResponse>().await?;
421
422            (false, resp.verified)
423        } else {
424            let url = format!("{}/api/v0/account/send-verification", self.sync_addr);
425            let url = Url::parse(url.as_str())?;
426
427            let resp = self.client.post(url).send().await?;
428            let resp = handle_resp_error(resp).await?;
429            let resp = resp.json::<SendVerificationResponse>().await?;
430
431            (resp.email_sent, resp.verified)
432        };
433
434        Ok((email_sent, verified))
435    }
436}