atuin_client/
api_client.rs

1use std::collections::HashMap;
2use std::env;
3use std::time::Duration;
4
5use eyre::{Result, bail};
6use reqwest::{
7    Response, StatusCode, Url,
8    header::{AUTHORIZATION, HeaderMap, USER_AGENT},
9};
10
11use atuin_common::{
12    api::{ATUIN_CARGO_VERSION, ATUIN_HEADER_VERSION, ATUIN_VERSION},
13    record::{EncryptedData, HostId, Record, RecordIdx},
14};
15use atuin_common::{
16    api::{
17        AddHistoryRequest, ChangePasswordRequest, CountResponse, DeleteHistoryRequest,
18        ErrorResponse, LoginRequest, LoginResponse, MeResponse, RegisterResponse,
19        SendVerificationResponse, StatusResponse, SyncHistoryResponse, VerificationTokenRequest,
20        VerificationTokenResponse,
21    },
22    record::RecordStatus,
23};
24
25use semver::Version;
26use time::OffsetDateTime;
27use time::format_description::well_known::Rfc3339;
28
29use crate::{history::History, sync::hash_str, utils::get_host_user};
30
31static APP_USER_AGENT: &str = concat!("atuin/", env!("CARGO_PKG_VERSION"),);
32
33pub struct Client<'a> {
34    sync_addr: &'a str,
35    client: reqwest::Client,
36}
37
38pub async fn register(
39    address: &str,
40    username: &str,
41    email: &str,
42    password: &str,
43) -> Result<RegisterResponse> {
44    let mut map = HashMap::new();
45    map.insert("username", username);
46    map.insert("email", email);
47    map.insert("password", password);
48
49    let url = format!("{address}/user/{username}");
50    let resp = reqwest::get(url).await?;
51
52    if resp.status().is_success() {
53        bail!("username already in use");
54    }
55
56    let url = format!("{address}/register");
57    let client = reqwest::Client::new();
58    let resp = client
59        .post(url)
60        .header(USER_AGENT, APP_USER_AGENT)
61        .header(ATUIN_HEADER_VERSION, ATUIN_CARGO_VERSION)
62        .json(&map)
63        .send()
64        .await?;
65    let resp = handle_resp_error(resp).await?;
66
67    if !ensure_version(&resp)? {
68        bail!("could not register user due to version mismatch");
69    }
70
71    let session = resp.json::<RegisterResponse>().await?;
72    Ok(session)
73}
74
75pub async fn login(address: &str, req: LoginRequest) -> Result<LoginResponse> {
76    let url = format!("{address}/login");
77    let client = reqwest::Client::new();
78
79    let resp = client
80        .post(url)
81        .header(USER_AGENT, APP_USER_AGENT)
82        .json(&req)
83        .send()
84        .await?;
85    let resp = handle_resp_error(resp).await?;
86
87    if !ensure_version(&resp)? {
88        bail!("Could not login due to version mismatch");
89    }
90
91    let session = resp.json::<LoginResponse>().await?;
92    Ok(session)
93}
94
/// Query `https://api.atuin.sh` for the newest published Atuin version.
///
/// The index endpoint's JSON body is decoded as an `IndexResponse`, and its
/// `version` field is parsed as semver.
///
/// # Errors
/// Fails on transport errors, on an error status, on an undecodable body, or on
/// a version string that is not valid semver.
#[cfg(feature = "check-update")]
pub async fn latest_version() -> Result<Version> {
    use atuin_common::api::IndexResponse;

    const INDEX_URL: &str = "https://api.atuin.sh";

    let response = reqwest::Client::new()
        .get(INDEX_URL)
        .header(USER_AGENT, APP_USER_AGENT)
        .send()
        .await?;
    let index: IndexResponse = handle_resp_error(response).await?.json().await?;

    Ok(Version::parse(index.version.as_str())?)
}
114
115pub fn ensure_version(response: &Response) -> Result<bool> {
116    let version = response.headers().get(ATUIN_HEADER_VERSION);
117
118    let version = if let Some(version) = version {
119        match version.to_str() {
120            Ok(v) => Version::parse(v),
121            Err(e) => bail!("failed to parse server version: {:?}", e),
122        }
123    } else {
124        bail!("Server not reporting its version: it is either too old or unhealthy");
125    }?;
126
127    // If the client is newer than the server
128    if version.major < ATUIN_VERSION.major {
129        println!(
130            "Atuin version mismatch! In order to successfully sync, the server needs to run a newer version of Atuin"
131        );
132        println!("Client: {}", ATUIN_CARGO_VERSION);
133        println!("Server: {}", version);
134
135        return Ok(false);
136    }
137
138    Ok(true)
139}
140
141async fn handle_resp_error(resp: Response) -> Result<Response> {
142    let status = resp.status();
143
144    if status == StatusCode::SERVICE_UNAVAILABLE {
145        bail!(
146            "Service unavailable: check https://status.atuin.sh (or get in touch with your host)"
147        );
148    }
149
150    if status == StatusCode::TOO_MANY_REQUESTS {
151        bail!("Rate limited; please wait before doing that again");
152    }
153
154    if !status.is_success() {
155        if let Ok(error) = resp.json::<ErrorResponse>().await {
156            let reason = error.reason;
157
158            if status.is_client_error() {
159                bail!("Invalid request to the service: {status} - {reason}.")
160            }
161
162            bail!(
163                "There was an error with the atuin sync service, server error {status}: {reason}.\nIf the problem persists, contact the host"
164            )
165        }
166
167        bail!(
168            "There was an error with the atuin sync service: Status {status:?}.\nIf the problem persists, contact the host"
169        )
170    }
171
172    Ok(resp)
173}
174
175impl<'a> Client<'a> {
176    pub fn new(
177        sync_addr: &'a str,
178        session_token: &str,
179        connect_timeout: u64,
180        timeout: u64,
181    ) -> Result<Self> {
182        let mut headers = HeaderMap::new();
183        headers.insert(AUTHORIZATION, format!("Token {session_token}").parse()?);
184
185        // used for semver server check
186        headers.insert(ATUIN_HEADER_VERSION, ATUIN_CARGO_VERSION.parse()?);
187
188        Ok(Client {
189            sync_addr,
190            client: reqwest::Client::builder()
191                .user_agent(APP_USER_AGENT)
192                .default_headers(headers)
193                .connect_timeout(Duration::new(connect_timeout, 0))
194                .timeout(Duration::new(timeout, 0))
195                .build()?,
196        })
197    }
198
199    pub async fn count(&self) -> Result<i64> {
200        let url = format!("{}/sync/count", self.sync_addr);
201        let url = Url::parse(url.as_str())?;
202
203        let resp = self.client.get(url).send().await?;
204        let resp = handle_resp_error(resp).await?;
205
206        if !ensure_version(&resp)? {
207            bail!("could not sync due to version mismatch");
208        }
209
210        if resp.status() != StatusCode::OK {
211            bail!("failed to get count (are you logged in?)");
212        }
213
214        let count = resp.json::<CountResponse>().await?;
215
216        Ok(count.count)
217    }
218
219    pub async fn status(&self) -> Result<StatusResponse> {
220        let url = format!("{}/sync/status", self.sync_addr);
221        let url = Url::parse(url.as_str())?;
222
223        let resp = self.client.get(url).send().await?;
224        let resp = handle_resp_error(resp).await?;
225
226        if !ensure_version(&resp)? {
227            bail!("could not sync due to version mismatch");
228        }
229
230        let status = resp.json::<StatusResponse>().await?;
231
232        Ok(status)
233    }
234
235    pub async fn me(&self) -> Result<MeResponse> {
236        let url = format!("{}/api/v0/me", self.sync_addr);
237        let url = Url::parse(url.as_str())?;
238
239        let resp = self.client.get(url).send().await?;
240        let resp = handle_resp_error(resp).await?;
241
242        let status = resp.json::<MeResponse>().await?;
243
244        Ok(status)
245    }
246
247    pub async fn get_history(
248        &self,
249        sync_ts: OffsetDateTime,
250        history_ts: OffsetDateTime,
251        host: Option<String>,
252    ) -> Result<SyncHistoryResponse> {
253        let host = host.unwrap_or_else(|| hash_str(&get_host_user()));
254
255        let url = format!(
256            "{}/sync/history?sync_ts={}&history_ts={}&host={}",
257            self.sync_addr,
258            urlencoding::encode(sync_ts.format(&Rfc3339)?.as_str()),
259            urlencoding::encode(history_ts.format(&Rfc3339)?.as_str()),
260            host,
261        );
262
263        let resp = self.client.get(url).send().await?;
264        let resp = handle_resp_error(resp).await?;
265
266        let history = resp.json::<SyncHistoryResponse>().await?;
267        Ok(history)
268    }
269
270    pub async fn post_history(&self, history: &[AddHistoryRequest]) -> Result<()> {
271        let url = format!("{}/history", self.sync_addr);
272        let url = Url::parse(url.as_str())?;
273
274        let resp = self.client.post(url).json(history).send().await?;
275        handle_resp_error(resp).await?;
276
277        Ok(())
278    }
279
280    pub async fn delete_history(&self, h: History) -> Result<()> {
281        let url = format!("{}/history", self.sync_addr);
282        let url = Url::parse(url.as_str())?;
283
284        let resp = self
285            .client
286            .delete(url)
287            .json(&DeleteHistoryRequest {
288                client_id: h.id.to_string(),
289            })
290            .send()
291            .await?;
292
293        handle_resp_error(resp).await?;
294
295        Ok(())
296    }
297
298    pub async fn delete_store(&self) -> Result<()> {
299        let url = format!("{}/api/v0/store", self.sync_addr);
300        let url = Url::parse(url.as_str())?;
301
302        let resp = self.client.delete(url).send().await?;
303
304        handle_resp_error(resp).await?;
305
306        Ok(())
307    }
308
309    pub async fn post_records(&self, records: &[Record<EncryptedData>]) -> Result<()> {
310        let url = format!("{}/api/v0/record", self.sync_addr);
311        let url = Url::parse(url.as_str())?;
312
313        debug!("uploading {} records to {url}", records.len());
314
315        let resp = self.client.post(url).json(records).send().await?;
316        handle_resp_error(resp).await?;
317
318        Ok(())
319    }
320
321    pub async fn next_records(
322        &self,
323        host: HostId,
324        tag: String,
325        start: RecordIdx,
326        count: u64,
327    ) -> Result<Vec<Record<EncryptedData>>> {
328        debug!(
329            "fetching record/s from host {}/{}/{}",
330            host.0.to_string(),
331            tag,
332            start
333        );
334
335        let url = format!(
336            "{}/api/v0/record/next?host={}&tag={}&count={}&start={}",
337            self.sync_addr, host.0, tag, count, start
338        );
339
340        let url = Url::parse(url.as_str())?;
341
342        let resp = self.client.get(url).send().await?;
343        let resp = handle_resp_error(resp).await?;
344
345        let records = resp.json::<Vec<Record<EncryptedData>>>().await?;
346
347        Ok(records)
348    }
349
350    pub async fn record_status(&self) -> Result<RecordStatus> {
351        let url = format!("{}/api/v0/record", self.sync_addr);
352        let url = Url::parse(url.as_str())?;
353
354        let resp = self.client.get(url).send().await?;
355        let resp = handle_resp_error(resp).await?;
356
357        if !ensure_version(&resp)? {
358            bail!("could not sync records due to version mismatch");
359        }
360
361        let index = resp.json().await?;
362
363        debug!("got remote index {:?}", index);
364
365        Ok(index)
366    }
367
368    pub async fn delete(&self) -> Result<()> {
369        let url = format!("{}/account", self.sync_addr);
370        let url = Url::parse(url.as_str())?;
371
372        let resp = self.client.delete(url).send().await?;
373
374        if resp.status() == 403 {
375            bail!("invalid login details");
376        } else if resp.status() == 200 {
377            Ok(())
378        } else {
379            bail!("Unknown error");
380        }
381    }
382
383    pub async fn change_password(
384        &self,
385        current_password: String,
386        new_password: String,
387    ) -> Result<()> {
388        let url = format!("{}/account/password", self.sync_addr);
389        let url = Url::parse(url.as_str())?;
390
391        let resp = self
392            .client
393            .patch(url)
394            .json(&ChangePasswordRequest {
395                current_password,
396                new_password,
397            })
398            .send()
399            .await?;
400
401        if resp.status() == 401 {
402            bail!("current password is incorrect")
403        } else if resp.status() == 403 {
404            bail!("invalid login details");
405        } else if resp.status() == 200 {
406            Ok(())
407        } else {
408            bail!("Unknown error");
409        }
410    }
411
412    // Either request a verification email if token is null, or validate a token
413    pub async fn verify(&self, token: Option<String>) -> Result<(bool, bool)> {
414        // could dedupe this a bit, but it's simple at the moment
415        let (email_sent, verified) = if let Some(token) = token {
416            let url = format!("{}/api/v0/account/verify", self.sync_addr);
417            let url = Url::parse(url.as_str())?;
418
419            let resp = self
420                .client
421                .post(url)
422                .json(&VerificationTokenRequest { token })
423                .send()
424                .await?;
425            let resp = handle_resp_error(resp).await?;
426            let resp = resp.json::<VerificationTokenResponse>().await?;
427
428            (false, resp.verified)
429        } else {
430            let url = format!("{}/api/v0/account/send-verification", self.sync_addr);
431            let url = Url::parse(url.as_str())?;
432
433            let resp = self.client.post(url).send().await?;
434            let resp = handle_resp_error(resp).await?;
435            let resp = resp.json::<SendVerificationResponse>().await?;
436
437            (resp.email_sent, resp.verified)
438        };
439
440        Ok((email_sent, verified))
441    }
442}