1use std::collections::BTreeMap;
6
7use atuin_client::record::sqlite_store::SqliteStore;
8use atuin_common::record::{DecryptedData, Host, HostId};
9use eyre::{bail, ensure, eyre, Result};
10
11use atuin_client::record::encryption::PASETO_V4;
12use atuin_client::record::store::Store;
13
14use crate::shell::Var;
15
16const DOTFILES_VAR_VERSION: &str = "v0";
17const DOTFILES_VAR_TAG: &str = "dotfiles-var";
18const DOTFILES_VAR_LEN: usize = 20000; #[derive(Debug, Clone, PartialEq, Eq)]
21pub enum VarRecord {
22 Create(Var), Delete(String), }
25
26impl VarRecord {
27 pub fn serialize(&self) -> Result<DecryptedData> {
28 use rmp::encode;
29
30 let mut output = vec![];
31
32 match self {
33 VarRecord::Create(env) => {
34 encode::write_u8(&mut output, 0)?; env.serialize(&mut output)?;
37 }
38 VarRecord::Delete(env) => {
39 encode::write_u8(&mut output, 1)?; encode::write_array_len(&mut output, 1)?; encode::write_str(&mut output, env.as_str())?;
43 }
44 }
45
46 Ok(DecryptedData(output))
47 }
48
49 pub fn deserialize(data: &DecryptedData, version: &str) -> Result<Self> {
50 use rmp::decode;
51
52 fn error_report<E: std::fmt::Debug>(err: E) -> eyre::Report {
53 eyre!("{err:?}")
54 }
55
56 match version {
57 DOTFILES_VAR_VERSION => {
58 let mut bytes = decode::Bytes::new(&data.0);
59
60 let record_type = decode::read_u8(&mut bytes).map_err(error_report)?;
61
62 match record_type {
63 0 => {
65 let env = Var::deserialize(&mut bytes)?;
66 Ok(VarRecord::Create(env))
67 }
68
69 1 => {
71 let nfields = decode::read_array_len(&mut bytes).map_err(error_report)?;
72 ensure!(
73 nfields == 1,
74 "too many entries in v0 dotfiles var delete record"
75 );
76
77 let bytes = bytes.remaining_slice();
78
79 let (key, bytes) =
80 decode::read_str_from_slice(bytes).map_err(error_report)?;
81
82 if !bytes.is_empty() {
83 bail!("trailing bytes in encoded dotfiles var record. malformed")
84 }
85
86 Ok(VarRecord::Delete(key.to_owned()))
87 }
88
89 n => {
90 bail!("unknown Dotfiles var record type {n}")
91 }
92 }
93 }
94 _ => {
95 bail!("unknown version {version:?}")
96 }
97 }
98 }
99}
100
101#[derive(Debug, Clone)]
102pub struct VarStore {
103 pub store: SqliteStore,
104 pub host_id: HostId,
105 pub encryption_key: [u8; 32],
106}
107
108impl VarStore {
109 pub fn new(store: SqliteStore, host_id: HostId, encryption_key: [u8; 32]) -> VarStore {
111 VarStore {
112 store,
113 host_id,
114 encryption_key,
115 }
116 }
117
118 pub async fn xonsh(&self) -> Result<String> {
119 let env = self.vars().await?;
120
121 let mut config = String::new();
122
123 for env in env {
124 config.push_str(&format!("${}={}\n", env.name, env.value));
125 }
126
127 Ok(config)
128 }
129
130 pub async fn fish(&self) -> Result<String> {
131 let env = self.vars().await?;
132
133 let mut config = String::new();
134
135 for env in env {
136 config.push_str(&format!("set -gx {} {}\n", env.name, env.value));
137 }
138
139 Ok(config)
140 }
141
142 pub async fn posix(&self) -> Result<String> {
143 let env = self.vars().await?;
144
145 let mut config = String::new();
146
147 for env in env {
148 if env.export {
149 config.push_str(&format!("export {}={}\n", env.name, env.value));
150 } else {
151 config.push_str(&format!("{}={}\n", env.name, env.value));
152 }
153 }
154
155 Ok(config)
156 }
157
158 pub async fn build(&self) -> Result<()> {
159 let dir = atuin_common::utils::dotfiles_cache_dir();
160 tokio::fs::create_dir_all(dir.clone()).await?;
161
162 let posix = self.posix().await?;
164 let xonsh = self.xonsh().await?;
165 let fsh = self.fish().await?;
166
167 let zsh = dir.join("vars.zsh");
171 let bash = dir.join("vars.bash");
172 let fish = dir.join("vars.fish");
173 let xsh = dir.join("vars.xsh");
174
175 tokio::fs::write(zsh, &posix).await?;
176 tokio::fs::write(bash, &posix).await?;
177 tokio::fs::write(fish, &fsh).await?;
178 tokio::fs::write(xsh, &xonsh).await?;
179
180 Ok(())
181 }
182
183 pub async fn set(&self, name: &str, value: &str, export: bool) -> Result<()> {
184 if name.len() + value.len() > DOTFILES_VAR_LEN {
185 return Err(eyre!(
186 "var record too large: max len {} bytes",
187 DOTFILES_VAR_LEN
188 ));
189 }
190
191 let record = VarRecord::Create(Var {
192 name: name.to_string(),
193 value: value.to_string(),
194 export,
195 });
196
197 let bytes = record.serialize()?;
198
199 let idx = self
200 .store
201 .last(self.host_id, DOTFILES_VAR_TAG)
202 .await?
203 .map_or(0, |entry| entry.idx + 1);
204
205 let record = atuin_common::record::Record::builder()
206 .host(Host::new(self.host_id))
207 .version(DOTFILES_VAR_VERSION.to_string())
208 .tag(DOTFILES_VAR_TAG.to_string())
209 .idx(idx)
210 .data(bytes)
211 .build();
212
213 self.store
214 .push(&record.encrypt::<PASETO_V4>(&self.encryption_key))
215 .await?;
216
217 self.build().await?;
219
220 Ok(())
221 }
222
223 pub async fn delete(&self, name: &str) -> Result<()> {
224 if name.len() > DOTFILES_VAR_LEN {
225 return Err(eyre!(
226 "var record too large: max len {} bytes",
227 DOTFILES_VAR_LEN,
228 ));
229 }
230
231 let record = VarRecord::Delete(name.to_string());
232
233 let bytes = record.serialize()?;
234
235 let idx = self
236 .store
237 .last(self.host_id, DOTFILES_VAR_TAG)
238 .await?
239 .map_or(0, |entry| entry.idx + 1);
240
241 let record = atuin_common::record::Record::builder()
242 .host(Host::new(self.host_id))
243 .version(DOTFILES_VAR_VERSION.to_string())
244 .tag(DOTFILES_VAR_TAG.to_string())
245 .idx(idx)
246 .data(bytes)
247 .build();
248
249 self.store
250 .push(&record.encrypt::<PASETO_V4>(&self.encryption_key))
251 .await?;
252
253 self.build().await?;
255
256 Ok(())
257 }
258
259 pub async fn vars(&self) -> Result<Vec<Var>> {
260 let mut build = BTreeMap::new();
261
262 let tagged = self.store.all_tagged(DOTFILES_VAR_TAG).await?;
264
265 for record in tagged {
266 let version = record.version.clone();
267
268 let decrypted = match version.as_str() {
269 DOTFILES_VAR_VERSION => record.decrypt::<PASETO_V4>(&self.encryption_key)?,
270 version => bail!("unknown version {version:?}"),
271 };
272
273 let ar = VarRecord::deserialize(&decrypted.data, version.as_str())?;
274
275 match ar {
276 VarRecord::Create(a) => {
277 build.insert(a.name.clone(), a);
278 }
279 VarRecord::Delete(d) => {
280 build.remove(&d);
281 }
282 }
283 }
284
285 Ok(build.into_values().collect())
286 }
287}
288
#[cfg(test)]
mod tests {
    // NOTE(review): `VarStore::set`/`delete` call `build`, which writes into the
    // real dotfiles cache dir. The store-level tests below therefore touch the
    // filesystem outside the in-memory database.
    use rand::rngs::OsRng;

    use atuin_client::record::sqlite_store::SqliteStore;
    use atuin_common::record::DecryptedData;

    use crate::{shell::Var, store::test_local_timeout};

    use super::{VarRecord, VarStore, DOTFILES_VAR_VERSION};
    use crypto_secretbox::{KeyInit, XSalsa20Poly1305};

    /// Build a `VarStore` over a fresh in-memory SQLite store with a random key.
    async fn test_store() -> VarStore {
        let store = SqliteStore::new(":memory:", test_local_timeout())
            .await
            .unwrap();
        let key: [u8; 32] = XSalsa20Poly1305::generate_key(&mut OsRng).into();
        let host_id = atuin_common::record::HostId(atuin_common::utils::uuid_v7());

        VarStore::new(store, host_id, key)
    }

    #[test]
    fn encode_decode() {
        let record = Var {
            name: "BEEP".to_owned(),
            value: "boop".to_owned(),
            export: false,
        };
        let record = VarRecord::Create(record);

        let snapshot = [
            204, 0, 147, 164, 66, 69, 69, 80, 164, 98, 111, 111, 112, 194,
        ];

        let encoded = record.serialize().unwrap();
        let decoded = VarRecord::deserialize(&encoded, DOTFILES_VAR_VERSION).unwrap();

        assert_eq!(encoded.0, &snapshot);
        assert_eq!(decoded, record);
    }

    #[test]
    fn encode_decode_delete() {
        let record = VarRecord::Delete("BEEP".to_owned());

        // u8 record type 1, fixarray of length 1, fixstr "BEEP"
        let snapshot = [204, 1, 145, 164, 66, 69, 69, 80];

        let encoded = record.serialize().unwrap();
        let decoded = VarRecord::deserialize(&encoded, DOTFILES_VAR_VERSION).unwrap();

        assert_eq!(encoded.0, &snapshot);
        assert_eq!(decoded, record);
    }

    #[test]
    fn decode_rejects_malformed() {
        let encoded = VarRecord::Delete("BEEP".to_owned()).serialize().unwrap();

        // Unknown format version.
        assert!(VarRecord::deserialize(&encoded, "v999").is_err());

        // Unknown record type (2).
        assert!(
            VarRecord::deserialize(&DecryptedData(vec![204, 2]), DOTFILES_VAR_VERSION).is_err()
        );

        // Trailing bytes after a valid delete record.
        let mut trailing = encoded.0.clone();
        trailing.push(0);
        assert!(VarRecord::deserialize(&DecryptedData(trailing), DOTFILES_VAR_VERSION).is_err());
    }

    #[tokio::test]
    async fn build_vars() {
        let env = test_store().await;

        env.set("BEEP", "boop", false).await.unwrap();
        env.set("HOMEBREW_NO_AUTO_UPDATE", "1", true).await.unwrap();

        let mut env_vars = env.vars().await.unwrap();

        env_vars.sort_by_key(|a| a.name.clone());

        assert_eq!(env_vars.len(), 2);

        assert_eq!(
            env_vars[0],
            Var {
                name: String::from("BEEP"),
                value: String::from("boop"),
                export: false,
            }
        );

        assert_eq!(
            env_vars[1],
            Var {
                name: String::from("HOMEBREW_NO_AUTO_UPDATE"),
                value: String::from("1"),
                export: true,
            }
        );
    }

    #[tokio::test]
    async fn delete_var() {
        let env = test_store().await;

        env.set("BEEP", "boop", false).await.unwrap();
        env.set("HOMEBREW_NO_AUTO_UPDATE", "1", true).await.unwrap();
        env.delete("BEEP").await.unwrap();

        let env_vars = env.vars().await.unwrap();

        assert_eq!(
            env_vars,
            vec![Var {
                name: String::from("HOMEBREW_NO_AUTO_UPDATE"),
                value: String::from("1"),
                export: true,
            }]
        );
    }
}