1use std::collections::BTreeMap;
2
3use atuin_common::record::{DecryptedData, Host, HostId};
4use eyre::{bail, ensure, eyre, Result};
5use serde::Deserialize;
6
7use crate::record::encryption::PASETO_V4;
8use crate::record::store::Store;
9
/// Tag under which every kv record is stored in the record store.
const KV_TAG: &str = "kv";
/// Record format version written by `KvStore::set` and the only one `KvRecord::deserialize` accepts.
const KV_VERSION: &str = "v0";
/// Largest value, in bytes, that `KvStore::set` will accept (100 KiB).
const KV_VAL_MAX_LEN: usize = 1024 * 100;
13
/// One key/value entry scoped to a namespace.
///
/// This is the plaintext payload of a kv record; `KvRecord::serialize` defines
/// how it is encoded before encryption.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct KvRecord {
    /// Namespace the key belongs to.
    pub namespace: String,
    /// Key within `namespace`.
    pub key: String,
    /// Value stored for `key`.
    pub value: String,
}
20
21impl KvRecord {
22 pub fn serialize(&self) -> Result<DecryptedData> {
23 use rmp::encode;
24
25 let mut output = vec![];
26
27 encode::write_array_len(&mut output, 3)?;
29
30 encode::write_str(&mut output, &self.namespace)?;
31 encode::write_str(&mut output, &self.key)?;
32 encode::write_str(&mut output, &self.value)?;
33
34 Ok(DecryptedData(output))
35 }
36
37 pub fn deserialize(data: &DecryptedData, version: &str) -> Result<Self> {
38 use rmp::decode;
39
40 fn error_report<E: std::fmt::Debug>(err: E) -> eyre::Report {
41 eyre!("{err:?}")
42 }
43
44 match version {
45 KV_VERSION => {
46 let mut bytes = decode::Bytes::new(&data.0);
47
48 let nfields = decode::read_array_len(&mut bytes).map_err(error_report)?;
49 ensure!(nfields == 3, "too many entries in v0 kv record");
50
51 let bytes = bytes.remaining_slice();
52
53 let (namespace, bytes) =
54 decode::read_str_from_slice(bytes).map_err(error_report)?;
55 let (key, bytes) = decode::read_str_from_slice(bytes).map_err(error_report)?;
56 let (value, bytes) = decode::read_str_from_slice(bytes).map_err(error_report)?;
57
58 if !bytes.is_empty() {
59 bail!("trailing bytes in encoded kvrecord. malformed")
60 }
61
62 Ok(KvRecord {
63 namespace: namespace.to_owned(),
64 key: key.to_owned(),
65 value: value.to_owned(),
66 })
67 }
68 _ => {
69 bail!("unknown version {version:?}")
70 }
71 }
72 }
73}
74
75#[derive(Debug, Clone, Deserialize)]
76pub struct KvStore;
77
78impl Default for KvStore {
79 fn default() -> Self {
80 Self::new()
81 }
82}
83
84impl KvStore {
85 pub fn new() -> KvStore {
87 KvStore {}
88 }
89
90 pub async fn set(
91 &self,
92 store: &(impl Store + Send + Sync),
93 encryption_key: &[u8; 32],
94 host_id: HostId,
95 namespace: &str,
96 key: &str,
97 value: &str,
98 ) -> Result<()> {
99 if value.len() > KV_VAL_MAX_LEN {
100 return Err(eyre!(
101 "kv value too large: max len {} bytes",
102 KV_VAL_MAX_LEN
103 ));
104 }
105
106 let record = KvRecord {
107 namespace: namespace.to_string(),
108 key: key.to_string(),
109 value: value.to_string(),
110 };
111
112 let bytes = record.serialize()?;
113
114 let idx = store
115 .last(host_id, KV_TAG)
116 .await?
117 .map_or(0, |entry| entry.idx + 1);
118
119 let record = atuin_common::record::Record::builder()
120 .host(Host::new(host_id))
121 .version(KV_VERSION.to_string())
122 .tag(KV_TAG.to_string())
123 .idx(idx)
124 .data(bytes)
125 .build();
126
127 store
128 .push(&record.encrypt::<PASETO_V4>(encryption_key))
129 .await?;
130
131 Ok(())
132 }
133
134 pub async fn get(
137 &self,
138 store: &impl Store,
139 encryption_key: &[u8; 32],
140 namespace: &str,
141 key: &str,
142 ) -> Result<Option<KvRecord>> {
143 let map = self.build_kv(store, encryption_key).await?;
145
146 let res = map.get(namespace);
147
148 if let Some(ns) = res {
149 let value = ns.get(key);
150
151 Ok(value.cloned())
152 } else {
153 Ok(None)
154 }
155 }
156
157 pub async fn build_kv(
162 &self,
163 store: &impl Store,
164 encryption_key: &[u8; 32],
165 ) -> Result<BTreeMap<String, BTreeMap<String, KvRecord>>> {
166 let mut map = BTreeMap::new();
167
168 let tagged = store.all_tagged(KV_TAG).await?;
172
173 for record in tagged {
177 let decrypted = match record.version.as_str() {
178 KV_VERSION => record.decrypt::<PASETO_V4>(encryption_key)?,
179 version => bail!("unknown version {version:?}"),
180 };
181
182 let kv = KvRecord::deserialize(&decrypted.data, KV_VERSION)?;
183
184 let ns = map
185 .entry(kv.namespace.clone())
186 .or_insert_with(BTreeMap::new);
187
188 ns.insert(kv.key.clone(), kv);
189 }
190
191 Ok(map)
192 }
193}
194
#[cfg(test)]
mod tests {
    use atuin_common::record::DecryptedData;
    use crypto_secretbox::{KeyInit, XSalsa20Poly1305};
    use rand::rngs::OsRng;

    use crate::record::sqlite_store::SqliteStore;
    use crate::settings::test_local_timeout;

    use super::{KvRecord, KvStore, KV_VERSION};

    /// Round-trips a record through the v0 encoding and pins the exact bytes.
    #[test]
    fn encode_decode() {
        let kv = KvRecord {
            namespace: "foo".to_owned(),
            key: "bar".to_owned(),
            value: "baz".to_owned(),
        };
        let snapshot = [
            0x93, 0xa3, b'f', b'o', b'o', 0xa3, b'b', b'a', b'r', 0xa3, b'b', b'a', b'z',
        ];

        let encoded = kv.serialize().unwrap();
        let decoded = KvRecord::deserialize(&encoded, KV_VERSION).unwrap();

        assert_eq!(encoded.0, &snapshot);
        assert_eq!(decoded, kv);
    }

    /// Wrong field counts, truncated payloads and trailing data are rejected.
    #[test]
    fn decode_rejects_malformed() {
        // Array header declares two fields instead of three.
        let too_few = DecryptedData(vec![
            0x92, 0xa3, b'f', b'o', b'o', 0xa3, b'b', b'a', b'r',
        ]);
        assert!(KvRecord::deserialize(&too_few, KV_VERSION).is_err());

        // Array header declares three fields but only two strings follow.
        let truncated = DecryptedData(vec![
            0x93, 0xa3, b'f', b'o', b'o', 0xa3, b'b', b'a', b'r',
        ]);
        assert!(KvRecord::deserialize(&truncated, KV_VERSION).is_err());

        // A trailing nil byte after the three strings.
        let trailing = DecryptedData(vec![
            0x93, 0xa3, b'f', b'o', b'o', 0xa3, b'b', b'a', b'r', 0xa3, b'b', b'a', b'z', 0xc0,
        ]);
        assert!(KvRecord::deserialize(&trailing, KV_VERSION).is_err());
    }

    /// Only the v0 payload format is understood.
    #[test]
    fn decode_rejects_unknown_version() {
        let kv = KvRecord {
            namespace: "foo".to_owned(),
            key: "bar".to_owned(),
            value: "baz".to_owned(),
        };
        let encoded = kv.serialize().unwrap();

        assert!(KvRecord::deserialize(&encoded, "v1").is_err());
    }

    #[tokio::test]
    async fn build_kv() {
        let store = SqliteStore::new(":memory:", test_local_timeout())
            .await
            .unwrap();
        let kv = KvStore::new();
        let key: [u8; 32] = XSalsa20Poly1305::generate_key(&mut OsRng).into();
        let host_id = atuin_common::record::HostId(atuin_common::utils::uuid_v7());

        kv.set(&store, &key, host_id, "test-kv", "foo", "bar")
            .await
            .unwrap();

        kv.set(&store, &key, host_id, "test-kv", "1", "2")
            .await
            .unwrap();

        let map = kv.build_kv(&store, &key).await.unwrap();

        // Both keys were written to the same namespace.
        assert_eq!(map.len(), 1);

        assert_eq!(
            *map.get("test-kv")
                .expect("map namespace not set")
                .get("foo")
                .expect("map key not set"),
            KvRecord {
                namespace: String::from("test-kv"),
                key: String::from("foo"),
                value: String::from("bar")
            }
        );

        assert_eq!(
            *map.get("test-kv")
                .expect("map namespace not set")
                .get("1")
                .expect("map key not set"),
            KvRecord {
                namespace: String::from("test-kv"),
                key: String::from("1"),
                value: String::from("2")
            }
        );
    }
}