1use std::collections::HashMap;
2use std::env;
3use std::time::Duration;
4
5use eyre::{Result, bail};
6use reqwest::{
7 Response, StatusCode, Url,
8 header::{AUTHORIZATION, HeaderMap, USER_AGENT},
9};
10
11use atuin_common::{
12 api::{ATUIN_CARGO_VERSION, ATUIN_HEADER_VERSION, ATUIN_VERSION},
13 record::{EncryptedData, HostId, Record, RecordIdx},
14};
15use atuin_common::{
16 api::{
17 AddHistoryRequest, ChangePasswordRequest, CountResponse, DeleteHistoryRequest,
18 ErrorResponse, LoginRequest, LoginResponse, MeResponse, RegisterResponse,
19 SendVerificationResponse, StatusResponse, SyncHistoryResponse, VerificationTokenRequest,
20 VerificationTokenResponse,
21 },
22 record::RecordStatus,
23};
24
25use semver::Version;
26use time::OffsetDateTime;
27use time::format_description::well_known::Rfc3339;
28
29use crate::{history::History, sync::hash_str, utils::get_host_user};
30
31static APP_USER_AGENT: &str = concat!("atuin/", env!("CARGO_PKG_VERSION"),);
32
33pub struct Client<'a> {
34 sync_addr: &'a str,
35 client: reqwest::Client,
36}
37
38pub async fn register(
39 address: &str,
40 username: &str,
41 email: &str,
42 password: &str,
43) -> Result<RegisterResponse> {
44 let mut map = HashMap::new();
45 map.insert("username", username);
46 map.insert("email", email);
47 map.insert("password", password);
48
49 let url = format!("{address}/user/{username}");
50 let resp = reqwest::get(url).await?;
51
52 if resp.status().is_success() {
53 bail!("username already in use");
54 }
55
56 let url = format!("{address}/register");
57 let client = reqwest::Client::new();
58 let resp = client
59 .post(url)
60 .header(USER_AGENT, APP_USER_AGENT)
61 .header(ATUIN_HEADER_VERSION, ATUIN_CARGO_VERSION)
62 .json(&map)
63 .send()
64 .await?;
65 let resp = handle_resp_error(resp).await?;
66
67 if !ensure_version(&resp)? {
68 bail!("could not register user due to version mismatch");
69 }
70
71 let session = resp.json::<RegisterResponse>().await?;
72 Ok(session)
73}
74
75pub async fn login(address: &str, req: LoginRequest) -> Result<LoginResponse> {
76 let url = format!("{address}/login");
77 let client = reqwest::Client::new();
78
79 let resp = client
80 .post(url)
81 .header(USER_AGENT, APP_USER_AGENT)
82 .json(&req)
83 .send()
84 .await?;
85 let resp = handle_resp_error(resp).await?;
86
87 if !ensure_version(&resp)? {
88 bail!("Could not login due to version mismatch");
89 }
90
91 let session = resp.json::<LoginResponse>().await?;
92 Ok(session)
93}
94
/// Asks `https://api.atuin.sh` for the most recently published atuin version.
///
/// # Errors
///
/// Fails on transport or server errors, when the body is not an
/// `IndexResponse`, or when its `version` field is not valid semver.
#[cfg(feature = "check-update")]
pub async fn latest_version() -> Result<Version> {
    use atuin_common::api::IndexResponse;

    const INDEX_URL: &str = "https://api.atuin.sh";

    let response = reqwest::Client::new()
        .get(INDEX_URL)
        .header(USER_AGENT, APP_USER_AGENT)
        .send()
        .await?;
    let checked = handle_resp_error(response).await?;
    let index: IndexResponse = checked.json().await?;

    let published = Version::parse(&index.version)?;
    Ok(published)
}
114
115pub fn ensure_version(response: &Response) -> Result<bool> {
116 let version = response.headers().get(ATUIN_HEADER_VERSION);
117
118 let version = if let Some(version) = version {
119 match version.to_str() {
120 Ok(v) => Version::parse(v),
121 Err(e) => bail!("failed to parse server version: {:?}", e),
122 }
123 } else {
124 bail!("Server not reporting its version: it is either too old or unhealthy");
125 }?;
126
127 if version.major < ATUIN_VERSION.major {
129 println!(
130 "Atuin version mismatch! In order to successfully sync, the server needs to run a newer version of Atuin"
131 );
132 println!("Client: {}", ATUIN_CARGO_VERSION);
133 println!("Server: {}", version);
134
135 return Ok(false);
136 }
137
138 Ok(true)
139}
140
141async fn handle_resp_error(resp: Response) -> Result<Response> {
142 let status = resp.status();
143
144 if status == StatusCode::SERVICE_UNAVAILABLE {
145 bail!(
146 "Service unavailable: check https://status.atuin.sh (or get in touch with your host)"
147 );
148 }
149
150 if status == StatusCode::TOO_MANY_REQUESTS {
151 bail!("Rate limited; please wait before doing that again");
152 }
153
154 if !status.is_success() {
155 if let Ok(error) = resp.json::<ErrorResponse>().await {
156 let reason = error.reason;
157
158 if status.is_client_error() {
159 bail!("Invalid request to the service: {status} - {reason}.")
160 }
161
162 bail!(
163 "There was an error with the atuin sync service, server error {status}: {reason}.\nIf the problem persists, contact the host"
164 )
165 }
166
167 bail!(
168 "There was an error with the atuin sync service: Status {status:?}.\nIf the problem persists, contact the host"
169 )
170 }
171
172 Ok(resp)
173}
174
175impl<'a> Client<'a> {
176 pub fn new(
177 sync_addr: &'a str,
178 session_token: &str,
179 connect_timeout: u64,
180 timeout: u64,
181 ) -> Result<Self> {
182 let mut headers = HeaderMap::new();
183 headers.insert(AUTHORIZATION, format!("Token {session_token}").parse()?);
184
185 headers.insert(ATUIN_HEADER_VERSION, ATUIN_CARGO_VERSION.parse()?);
187
188 Ok(Client {
189 sync_addr,
190 client: reqwest::Client::builder()
191 .user_agent(APP_USER_AGENT)
192 .default_headers(headers)
193 .connect_timeout(Duration::new(connect_timeout, 0))
194 .timeout(Duration::new(timeout, 0))
195 .build()?,
196 })
197 }
198
199 pub async fn count(&self) -> Result<i64> {
200 let url = format!("{}/sync/count", self.sync_addr);
201 let url = Url::parse(url.as_str())?;
202
203 let resp = self.client.get(url).send().await?;
204 let resp = handle_resp_error(resp).await?;
205
206 if !ensure_version(&resp)? {
207 bail!("could not sync due to version mismatch");
208 }
209
210 if resp.status() != StatusCode::OK {
211 bail!("failed to get count (are you logged in?)");
212 }
213
214 let count = resp.json::<CountResponse>().await?;
215
216 Ok(count.count)
217 }
218
219 pub async fn status(&self) -> Result<StatusResponse> {
220 let url = format!("{}/sync/status", self.sync_addr);
221 let url = Url::parse(url.as_str())?;
222
223 let resp = self.client.get(url).send().await?;
224 let resp = handle_resp_error(resp).await?;
225
226 if !ensure_version(&resp)? {
227 bail!("could not sync due to version mismatch");
228 }
229
230 let status = resp.json::<StatusResponse>().await?;
231
232 Ok(status)
233 }
234
235 pub async fn me(&self) -> Result<MeResponse> {
236 let url = format!("{}/api/v0/me", self.sync_addr);
237 let url = Url::parse(url.as_str())?;
238
239 let resp = self.client.get(url).send().await?;
240 let resp = handle_resp_error(resp).await?;
241
242 let status = resp.json::<MeResponse>().await?;
243
244 Ok(status)
245 }
246
247 pub async fn get_history(
248 &self,
249 sync_ts: OffsetDateTime,
250 history_ts: OffsetDateTime,
251 host: Option<String>,
252 ) -> Result<SyncHistoryResponse> {
253 let host = host.unwrap_or_else(|| hash_str(&get_host_user()));
254
255 let url = format!(
256 "{}/sync/history?sync_ts={}&history_ts={}&host={}",
257 self.sync_addr,
258 urlencoding::encode(sync_ts.format(&Rfc3339)?.as_str()),
259 urlencoding::encode(history_ts.format(&Rfc3339)?.as_str()),
260 host,
261 );
262
263 let resp = self.client.get(url).send().await?;
264 let resp = handle_resp_error(resp).await?;
265
266 let history = resp.json::<SyncHistoryResponse>().await?;
267 Ok(history)
268 }
269
270 pub async fn post_history(&self, history: &[AddHistoryRequest]) -> Result<()> {
271 let url = format!("{}/history", self.sync_addr);
272 let url = Url::parse(url.as_str())?;
273
274 let resp = self.client.post(url).json(history).send().await?;
275 handle_resp_error(resp).await?;
276
277 Ok(())
278 }
279
280 pub async fn delete_history(&self, h: History) -> Result<()> {
281 let url = format!("{}/history", self.sync_addr);
282 let url = Url::parse(url.as_str())?;
283
284 let resp = self
285 .client
286 .delete(url)
287 .json(&DeleteHistoryRequest {
288 client_id: h.id.to_string(),
289 })
290 .send()
291 .await?;
292
293 handle_resp_error(resp).await?;
294
295 Ok(())
296 }
297
298 pub async fn delete_store(&self) -> Result<()> {
299 let url = format!("{}/api/v0/store", self.sync_addr);
300 let url = Url::parse(url.as_str())?;
301
302 let resp = self.client.delete(url).send().await?;
303
304 handle_resp_error(resp).await?;
305
306 Ok(())
307 }
308
309 pub async fn post_records(&self, records: &[Record<EncryptedData>]) -> Result<()> {
310 let url = format!("{}/api/v0/record", self.sync_addr);
311 let url = Url::parse(url.as_str())?;
312
313 debug!("uploading {} records to {url}", records.len());
314
315 let resp = self.client.post(url).json(records).send().await?;
316 handle_resp_error(resp).await?;
317
318 Ok(())
319 }
320
321 pub async fn next_records(
322 &self,
323 host: HostId,
324 tag: String,
325 start: RecordIdx,
326 count: u64,
327 ) -> Result<Vec<Record<EncryptedData>>> {
328 debug!(
329 "fetching record/s from host {}/{}/{}",
330 host.0.to_string(),
331 tag,
332 start
333 );
334
335 let url = format!(
336 "{}/api/v0/record/next?host={}&tag={}&count={}&start={}",
337 self.sync_addr, host.0, tag, count, start
338 );
339
340 let url = Url::parse(url.as_str())?;
341
342 let resp = self.client.get(url).send().await?;
343 let resp = handle_resp_error(resp).await?;
344
345 let records = resp.json::<Vec<Record<EncryptedData>>>().await?;
346
347 Ok(records)
348 }
349
350 pub async fn record_status(&self) -> Result<RecordStatus> {
351 let url = format!("{}/api/v0/record", self.sync_addr);
352 let url = Url::parse(url.as_str())?;
353
354 let resp = self.client.get(url).send().await?;
355 let resp = handle_resp_error(resp).await?;
356
357 if !ensure_version(&resp)? {
358 bail!("could not sync records due to version mismatch");
359 }
360
361 let index = resp.json().await?;
362
363 debug!("got remote index {:?}", index);
364
365 Ok(index)
366 }
367
368 pub async fn delete(&self) -> Result<()> {
369 let url = format!("{}/account", self.sync_addr);
370 let url = Url::parse(url.as_str())?;
371
372 let resp = self.client.delete(url).send().await?;
373
374 if resp.status() == 403 {
375 bail!("invalid login details");
376 } else if resp.status() == 200 {
377 Ok(())
378 } else {
379 bail!("Unknown error");
380 }
381 }
382
383 pub async fn change_password(
384 &self,
385 current_password: String,
386 new_password: String,
387 ) -> Result<()> {
388 let url = format!("{}/account/password", self.sync_addr);
389 let url = Url::parse(url.as_str())?;
390
391 let resp = self
392 .client
393 .patch(url)
394 .json(&ChangePasswordRequest {
395 current_password,
396 new_password,
397 })
398 .send()
399 .await?;
400
401 if resp.status() == 401 {
402 bail!("current password is incorrect")
403 } else if resp.status() == 403 {
404 bail!("invalid login details");
405 } else if resp.status() == 200 {
406 Ok(())
407 } else {
408 bail!("Unknown error");
409 }
410 }
411
412 pub async fn verify(&self, token: Option<String>) -> Result<(bool, bool)> {
414 let (email_sent, verified) = if let Some(token) = token {
416 let url = format!("{}/api/v0/account/verify", self.sync_addr);
417 let url = Url::parse(url.as_str())?;
418
419 let resp = self
420 .client
421 .post(url)
422 .json(&VerificationTokenRequest { token })
423 .send()
424 .await?;
425 let resp = handle_resp_error(resp).await?;
426 let resp = resp.json::<VerificationTokenResponse>().await?;
427
428 (false, resp.verified)
429 } else {
430 let url = format!("{}/api/v0/account/send-verification", self.sync_addr);
431 let url = Url::parse(url.as_str())?;
432
433 let resp = self.client.post(url).send().await?;
434 let resp = handle_resp_error(resp).await?;
435 let resp = resp.json::<SendVerificationResponse>().await?;
436
437 (resp.email_sent, resp.verified)
438 };
439
440 Ok((email_sent, verified))
441 }
442}