// atuin_dotfiles/store/var.rs

//! Store for shell vars.
//!
//! TODO: abstract this and share code between the alias/env stores. Keeping it
//! separate is easier for now — once there are two implementations, building a
//! common base is much easier.
5use std::collections::BTreeMap;
6
7use atuin_client::record::sqlite_store::SqliteStore;
8use atuin_common::record::{DecryptedData, Host, HostId};
9use eyre::{bail, ensure, eyre, Result};
10
11use atuin_client::record::encryption::PASETO_V4;
12use atuin_client::record::store::Store;
13
14use crate::shell::Var;
15
16const DOTFILES_VAR_VERSION: &str = "v0";
17const DOTFILES_VAR_TAG: &str = "dotfiles-var";
18const DOTFILES_VAR_LEN: usize = 20000; // 20kb max total len, way more than should be needed.
19
20#[derive(Debug, Clone, PartialEq, Eq)]
21pub enum VarRecord {
22    Create(Var),    // create a full record
23    Delete(String), // delete by name
24}
25
26impl VarRecord {
27    pub fn serialize(&self) -> Result<DecryptedData> {
28        use rmp::encode;
29
30        let mut output = vec![];
31
32        match self {
33            VarRecord::Create(env) => {
34                encode::write_u8(&mut output, 0)?; // create
35
36                env.serialize(&mut output)?;
37            }
38            VarRecord::Delete(env) => {
39                encode::write_u8(&mut output, 1)?; // delete
40                encode::write_array_len(&mut output, 1)?; // 1 field
41
42                encode::write_str(&mut output, env.as_str())?;
43            }
44        }
45
46        Ok(DecryptedData(output))
47    }
48
49    pub fn deserialize(data: &DecryptedData, version: &str) -> Result<Self> {
50        use rmp::decode;
51
52        fn error_report<E: std::fmt::Debug>(err: E) -> eyre::Report {
53            eyre!("{err:?}")
54        }
55
56        match version {
57            DOTFILES_VAR_VERSION => {
58                let mut bytes = decode::Bytes::new(&data.0);
59
60                let record_type = decode::read_u8(&mut bytes).map_err(error_report)?;
61
62                match record_type {
63                    // create
64                    0 => {
65                        let env = Var::deserialize(&mut bytes)?;
66                        Ok(VarRecord::Create(env))
67                    }
68
69                    // delete
70                    1 => {
71                        let nfields = decode::read_array_len(&mut bytes).map_err(error_report)?;
72                        ensure!(
73                            nfields == 1,
74                            "too many entries in v0 dotfiles var delete record"
75                        );
76
77                        let bytes = bytes.remaining_slice();
78
79                        let (key, bytes) =
80                            decode::read_str_from_slice(bytes).map_err(error_report)?;
81
82                        if !bytes.is_empty() {
83                            bail!("trailing bytes in encoded dotfiles var record. malformed")
84                        }
85
86                        Ok(VarRecord::Delete(key.to_owned()))
87                    }
88
89                    n => {
90                        bail!("unknown Dotfiles var record type {n}")
91                    }
92                }
93            }
94            _ => {
95                bail!("unknown version {version:?}")
96            }
97        }
98    }
99}
100
101#[derive(Debug, Clone)]
102pub struct VarStore {
103    pub store: SqliteStore,
104    pub host_id: HostId,
105    pub encryption_key: [u8; 32],
106}
107
108impl VarStore {
109    // will want to init the actual kv store when that is done
110    pub fn new(store: SqliteStore, host_id: HostId, encryption_key: [u8; 32]) -> VarStore {
111        VarStore {
112            store,
113            host_id,
114            encryption_key,
115        }
116    }
117
118    pub async fn xonsh(&self) -> Result<String> {
119        let env = self.vars().await?;
120
121        let mut config = String::new();
122
123        for env in env {
124            config.push_str(&format!("${}={}\n", env.name, env.value));
125        }
126
127        Ok(config)
128    }
129
130    pub async fn fish(&self) -> Result<String> {
131        let env = self.vars().await?;
132
133        let mut config = String::new();
134
135        for env in env {
136            config.push_str(&format!("set -gx {} {}\n", env.name, env.value));
137        }
138
139        Ok(config)
140    }
141
142    pub async fn posix(&self) -> Result<String> {
143        let env = self.vars().await?;
144
145        let mut config = String::new();
146
147        for env in env {
148            if env.export {
149                config.push_str(&format!("export {}={}\n", env.name, env.value));
150            } else {
151                config.push_str(&format!("{}={}\n", env.name, env.value));
152            }
153        }
154
155        Ok(config)
156    }
157
158    pub async fn build(&self) -> Result<()> {
159        let dir = atuin_common::utils::dotfiles_cache_dir();
160        tokio::fs::create_dir_all(dir.clone()).await?;
161
162        // Build for all supported shells
163        let posix = self.posix().await?;
164        let xonsh = self.xonsh().await?;
165        let fsh = self.fish().await?;
166
167        // All the same contents, maybe optimize in the future or perhaps there will be quirks
168        // per-shell
169        // I'd prefer separation atm
170        let zsh = dir.join("vars.zsh");
171        let bash = dir.join("vars.bash");
172        let fish = dir.join("vars.fish");
173        let xsh = dir.join("vars.xsh");
174
175        tokio::fs::write(zsh, &posix).await?;
176        tokio::fs::write(bash, &posix).await?;
177        tokio::fs::write(fish, &fsh).await?;
178        tokio::fs::write(xsh, &xonsh).await?;
179
180        Ok(())
181    }
182
183    pub async fn set(&self, name: &str, value: &str, export: bool) -> Result<()> {
184        if name.len() + value.len() > DOTFILES_VAR_LEN {
185            return Err(eyre!(
186                "var record too large: max len {} bytes",
187                DOTFILES_VAR_LEN
188            ));
189        }
190
191        let record = VarRecord::Create(Var {
192            name: name.to_string(),
193            value: value.to_string(),
194            export,
195        });
196
197        let bytes = record.serialize()?;
198
199        let idx = self
200            .store
201            .last(self.host_id, DOTFILES_VAR_TAG)
202            .await?
203            .map_or(0, |entry| entry.idx + 1);
204
205        let record = atuin_common::record::Record::builder()
206            .host(Host::new(self.host_id))
207            .version(DOTFILES_VAR_VERSION.to_string())
208            .tag(DOTFILES_VAR_TAG.to_string())
209            .idx(idx)
210            .data(bytes)
211            .build();
212
213        self.store
214            .push(&record.encrypt::<PASETO_V4>(&self.encryption_key))
215            .await?;
216
217        // set mutates shell config, so build again
218        self.build().await?;
219
220        Ok(())
221    }
222
223    pub async fn delete(&self, name: &str) -> Result<()> {
224        if name.len() > DOTFILES_VAR_LEN {
225            return Err(eyre!(
226                "var record too large: max len {} bytes",
227                DOTFILES_VAR_LEN,
228            ));
229        }
230
231        let record = VarRecord::Delete(name.to_string());
232
233        let bytes = record.serialize()?;
234
235        let idx = self
236            .store
237            .last(self.host_id, DOTFILES_VAR_TAG)
238            .await?
239            .map_or(0, |entry| entry.idx + 1);
240
241        let record = atuin_common::record::Record::builder()
242            .host(Host::new(self.host_id))
243            .version(DOTFILES_VAR_VERSION.to_string())
244            .tag(DOTFILES_VAR_TAG.to_string())
245            .idx(idx)
246            .data(bytes)
247            .build();
248
249        self.store
250            .push(&record.encrypt::<PASETO_V4>(&self.encryption_key))
251            .await?;
252
253        // delete mutates shell config, so build again
254        self.build().await?;
255
256        Ok(())
257    }
258
259    pub async fn vars(&self) -> Result<Vec<Var>> {
260        let mut build = BTreeMap::new();
261
262        // this is sorted, oldest to newest
263        let tagged = self.store.all_tagged(DOTFILES_VAR_TAG).await?;
264
265        for record in tagged {
266            let version = record.version.clone();
267
268            let decrypted = match version.as_str() {
269                DOTFILES_VAR_VERSION => record.decrypt::<PASETO_V4>(&self.encryption_key)?,
270                version => bail!("unknown version {version:?}"),
271            };
272
273            let ar = VarRecord::deserialize(&decrypted.data, version.as_str())?;
274
275            match ar {
276                VarRecord::Create(a) => {
277                    build.insert(a.name.clone(), a);
278                }
279                VarRecord::Delete(d) => {
280                    build.remove(&d);
281                }
282            }
283        }
284
285        Ok(build.into_values().collect())
286    }
287}
288
289#[cfg(test)]
290mod tests {
291    use rand::rngs::OsRng;
292
293    use atuin_client::record::sqlite_store::SqliteStore;
294
295    use crate::{shell::Var, store::test_local_timeout};
296
297    use super::{VarRecord, VarStore, DOTFILES_VAR_VERSION};
298    use crypto_secretbox::{KeyInit, XSalsa20Poly1305};
299
300    #[test]
301    fn encode_decode() {
302        let record = Var {
303            name: "BEEP".to_owned(),
304            value: "boop".to_owned(),
305            export: false,
306        };
307        let record = VarRecord::Create(record);
308
309        let snapshot = [
310            204, 0, 147, 164, 66, 69, 69, 80, 164, 98, 111, 111, 112, 194,
311        ];
312
313        let encoded = record.serialize().unwrap();
314        let decoded = VarRecord::deserialize(&encoded, DOTFILES_VAR_VERSION).unwrap();
315
316        assert_eq!(encoded.0, &snapshot);
317        assert_eq!(decoded, record);
318    }
319
320    #[tokio::test]
321    async fn build_vars() {
322        let store = SqliteStore::new(":memory:", test_local_timeout())
323            .await
324            .unwrap();
325        let key: [u8; 32] = XSalsa20Poly1305::generate_key(&mut OsRng).into();
326        let host_id = atuin_common::record::HostId(atuin_common::utils::uuid_v7());
327
328        let env = VarStore::new(store, host_id, key);
329
330        env.set("BEEP", "boop", false).await.unwrap();
331        env.set("HOMEBREW_NO_AUTO_UPDATE", "1", true).await.unwrap();
332
333        let mut env_vars = env.vars().await.unwrap();
334
335        env_vars.sort_by_key(|a| a.name.clone());
336
337        assert_eq!(env_vars.len(), 2);
338
339        assert_eq!(
340            env_vars[0],
341            Var {
342                name: String::from("BEEP"),
343                value: String::from("boop"),
344                export: false,
345            }
346        );
347
348        assert_eq!(
349            env_vars[1],
350            Var {
351                name: String::from("HOMEBREW_NO_AUTO_UPDATE"),
352                value: String::from("1"),
353                export: true,
354            }
355        );
356    }
357}