1use std::{
2 convert::TryFrom,
3 io::{Read, Write},
4 iter::{FromIterator, IntoIterator},
5 slice::Iter,
6 str::FromStr,
7};
8
9use coins_core::ser::ByteFormat;
10
11use crate::{primitives::KeyFingerprint, Bip32Error, BIP32_HARDEN};
12
13fn try_parse_index(s: &str) -> Result<u32, Bip32Error> {
14 let mut index_str = s.to_owned();
15 let harden = if s.ends_with('\'') || s.ends_with('h') {
16 index_str.pop();
17 true
18 } else {
19 false
20 };
21
22 index_str
23 .parse::<u32>()
24 .map(|v| if harden { harden_index(v) } else { v })
25 .map_err(|_| Bip32Error::MalformattedDerivation(s.to_owned()))
26}
27
28fn encode_index(idx: u32, harden: char) -> String {
29 let mut s = (idx % BIP32_HARDEN).to_string();
30 if idx >= BIP32_HARDEN {
31 s.push(harden);
32 }
33 s
34}
35
36pub const fn harden_index(index: u32) -> u32 {
38 index + BIP32_HARDEN
39}
40
/// An ordered list of BIP32 child indices, root-most first.
///
/// Indices `>= BIP32_HARDEN` are hardened. The `Default` value is the empty
/// path, which renders as the bare root `"m"`.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct DerivationPath(Vec<u32>);
44
45impl serde::Serialize for DerivationPath {
46 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
47 where
48 S: serde::Serializer,
49 {
50 serializer.serialize_str(&self.derivation_string())
51 }
52}
53
54impl<'de> serde::Deserialize<'de> for DerivationPath {
55 fn deserialize<D>(deserializer: D) -> Result<DerivationPath, D::Error>
56 where
57 D: serde::Deserializer<'de>,
58 {
59 let s: &str = serde::Deserialize::deserialize(deserializer)?;
60 s.parse::<DerivationPath>()
61 .map_err(|e| serde::de::Error::custom(e.to_string()))
62 }
63}
64
65impl DerivationPath {
66 #[doc(hidden)]
67 pub fn custom_string(&self, root: &str, joiner: char, harden: char) -> String {
68 std::iter::once(root.to_owned())
69 .chain(self.0.iter().map(|s| encode_index(*s, harden)))
70 .collect::<Vec<String>>()
71 .join(&joiner.to_string())
72 }
73
74 pub fn last(&self) -> Option<&u32> {
76 self.0.last()
77 }
78
79 pub fn derivation_string(&self) -> String {
81 self.custom_string("m", '/', '\'')
82 }
83
84 pub fn is_empty(&self) -> bool {
86 self.0.is_empty()
87 }
88
89 pub fn len(&self) -> usize {
91 self.0.len()
92 }
93
94 pub fn iter(&self) -> Iter<u32> {
96 self.0.iter()
97 }
98
99 pub fn starts_with(&self, other: &Self) -> bool {
101 self.0.starts_with(&other.0)
102 }
103
104 pub fn without_prefix(&self, prefix: &Self) -> Option<DerivationPath> {
107 if !self.starts_with(prefix) {
108 None
109 } else {
110 Some(self.0[prefix.len()..].to_vec().into())
111 }
112 }
113
114 pub fn last_hardened(&self) -> (usize, Option<u32>) {
118 match self.iter().rev().position(|v| *v >= BIP32_HARDEN) {
119 Some(rev_pos) => {
120 let pos = self.len() - rev_pos - 1;
121 (pos, Some(self.0[pos]))
122 }
123 None => (0, None),
124 }
125 }
126
127 pub fn resized(&self, size: usize, pad_with: u32) -> Self {
130 let mut child = self.clone();
131 child.0.resize(size, pad_with);
132 child
133 }
134
135 pub fn extended(&self, idx: u32) -> Self {
137 let mut child = self.clone();
138 child.0.push(idx);
139 child
140 }
141}
142
143impl From<&DerivationPath> for DerivationPath {
144 fn from(v: &DerivationPath) -> Self {
145 v.clone()
146 }
147}
148
149impl From<Vec<u32>> for DerivationPath {
150 fn from(v: Vec<u32>) -> Self {
151 Self(v)
152 }
153}
154
155impl From<&Vec<u32>> for DerivationPath {
156 fn from(v: &Vec<u32>) -> Self {
157 Self(v.clone())
158 }
159}
160
161impl From<&[u32]> for DerivationPath {
162 fn from(v: &[u32]) -> Self {
163 Self(Vec::from(v))
164 }
165}
166
167impl TryFrom<u32> for DerivationPath {
168 type Error = Bip32Error;
169
170 fn try_from(v: u32) -> Result<Self, Self::Error> {
171 Ok(Self(vec![v]))
172 }
173}
174
175impl TryFrom<&str> for DerivationPath {
176 type Error = Bip32Error;
177
178 fn try_from(v: &str) -> Result<Self, Self::Error> {
179 v.parse()
180 }
181}
182
183impl FromIterator<u32> for DerivationPath {
184 fn from_iter<T>(iter: T) -> Self
185 where
186 T: IntoIterator<Item = u32>,
187 {
188 Vec::from_iter(iter).into()
189 }
190}
191
192impl FromStr for DerivationPath {
193 type Err = Bip32Error;
194
195 fn from_str(s: &str) -> Result<Self, Self::Err> {
196 s.split('/')
197 .filter(|v| v != &"m")
198 .map(try_parse_index)
199 .collect::<Result<Vec<u32>, Bip32Error>>()
200 .map(|v| v.into())
201 .map_err(|_| Bip32Error::MalformattedDerivation(s.to_owned()))
202 }
203}
204
205#[derive(Debug, Clone, Eq, PartialEq, serde::Serialize, serde::Deserialize)]
207pub struct KeyDerivation {
208 pub root: KeyFingerprint,
210 pub path: DerivationPath,
212}
213
214impl KeyDerivation {
215 pub fn same_root(&self, other: &Self) -> bool {
218 self.root == other.root
219 }
220
221 pub fn is_possible_ancestor_of(&self, other: &Self) -> bool {
224 self.same_root(other) && other.path.starts_with(&self.path)
225 }
226
227 pub fn path_to_descendant(&self, descendant: &Self) -> Option<DerivationPath> {
229 descendant.path.without_prefix(&self.path)
230 }
231
232 pub fn resized(&self, size: usize, pad_with: u32) -> Self {
235 Self {
236 root: self.root,
237 path: self.path.resized(size, pad_with),
238 }
239 }
240
241 pub fn extended(&self, idx: u32) -> Self {
243 Self {
244 root: self.root,
245 path: self.path.extended(idx),
246 }
247 }
248}
249
250impl ByteFormat for KeyDerivation {
251 type Error = Bip32Error;
252
253 fn serialized_length(&self) -> usize {
254 4 + 4 * self.path.len()
255 }
256
257 fn read_from<T>(_reader: &mut T) -> Result<Self, Self::Error>
258 where
259 T: Read,
260 Self: std::marker::Sized,
261 {
262 unimplemented!()
263 }
286
287 fn write_to<T>(&self, writer: &mut T) -> Result<usize, Self::Error>
288 where
289 T: Write,
290 {
291 let mut length = writer.write(&self.root.0)?;
292 for i in self.path.iter() {
293 length += writer.write(&i.to_le_bytes())?;
294 }
295 Ok(length)
296 }
297}
298
#[cfg(test)]
pub mod test {
    use super::*;

    #[test]
    fn it_parses_index_strings() {
        let cases = [("32", 32), ("32h", 32 + BIP32_HARDEN), ("0h", BIP32_HARDEN)];
        for case in cases.iter() {
            match try_parse_index(case.0) {
                Ok(v) => assert_eq!(v, case.1),
                Err(e) => panic!("unexpected error {}", e),
            }
        }
    }

    #[test]
    fn it_handles_malformatted_indices() {
        let cases = ["-", "h", "toast", "憂鬱"];
        for case in cases.iter() {
            match try_parse_index(case) {
                Ok(_) => panic!("expected an error"),
                Err(Bip32Error::MalformattedDerivation(e)) => assert_eq!(&e, case),
                Err(e) => panic!("unexpected error {}", e),
            }
        }
    }

    #[test]
    fn it_parses_derivation_strings() {
        let cases = [
            ("m", vec![]),
            ("m/32", vec![32]),
            ("m/32'", vec![32 + BIP32_HARDEN]),
            ("m/0'/32/5/5/5", vec![BIP32_HARDEN, 32, 5, 5, 5]),
            ("32", vec![32]),
            ("32'", vec![32 + BIP32_HARDEN]),
            ("0'/32/5/5/5", vec![BIP32_HARDEN, 32, 5, 5, 5]),
        ];
        for case in cases.iter() {
            match case.0.parse::<DerivationPath>() {
                Ok(v) => assert_eq!(v.0, case.1),
                Err(e) => panic!("unexpected error {}", e),
            }
        }
    }

    #[test]
    fn it_handles_malformatted_derivations() {
        let cases = ["//", "m/", "-", "h", "toast", "憂鬱"];
        for case in cases.iter() {
            match case.parse::<DerivationPath>() {
                Ok(_) => panic!("expected an error"),
                Err(Bip32Error::MalformattedDerivation(e)) => assert_eq!(&e, case),
                Err(e) => panic!("unexpected error {}", e),
            }
        }
    }

    #[test]
    fn it_removes_prefixes_from_derivations() {
        let cases = [
            (
                DerivationPath(vec![1, 2, 3]),
                DerivationPath(vec![1]),
                Some(DerivationPath(vec![2, 3])),
            ),
            (
                vec![1, 2, 3].into(),
                vec![1, 2].into(),
                Some(vec![3].into()),
            ),
            (
                (1u32..=3).collect(),
                (1u32..=3).collect(),
                Some((0..0).collect()),
            ),
            (DerivationPath(vec![1, 2, 3]), vec![1, 3].into(), None),
        ];
        for case in cases.iter() {
            assert_eq!(case.0.without_prefix(&case.1), case.2);
        }
    }

    // Exercises the `TryFrom<&str>` conversion, which the `str::parse` tests
    // above do not reach.
    #[test]
    fn it_produces_paths_from_strings_via_try_from() {
        let cases = ["//", "m/", "-", "h", "toast", "憂鬱"];
        for case in cases.iter() {
            match DerivationPath::try_from(*case) {
                Ok(_) => panic!("expected an error"),
                Err(Bip32Error::MalformattedDerivation(e)) => assert_eq!(&e, case),
                Err(e) => panic!("unexpected error {}", e),
            }
        }
        assert_eq!(
            DerivationPath::try_from("m/0'/1").ok(),
            Some(DerivationPath(vec![BIP32_HARDEN, 1]))
        );
    }

    #[test]
    fn it_stringifies_derivation_paths() {
        let cases = [
            (DerivationPath(vec![]), "m"),
            (DerivationPath(vec![1, 2, 3]), "m/1/2/3"),
            (
                vec![BIP32_HARDEN, BIP32_HARDEN, BIP32_HARDEN].into(),
                "m/0'/0'/0'",
            ),
        ];
        for case in cases.iter() {
            assert_eq!(&case.0.derivation_string(), case.1);
        }
    }

    #[test]
    fn it_round_trips_derivation_strings() {
        let paths = [
            DerivationPath(vec![]),
            DerivationPath(vec![1, 2, 3]),
            DerivationPath(vec![BIP32_HARDEN, 5, BIP32_HARDEN + 7]),
            DerivationPath(vec![u32::MAX]),
        ];
        for path in paths.iter() {
            let parsed = path.derivation_string().parse::<DerivationPath>().ok();
            assert_eq!(parsed.as_ref(), Some(path));
        }
    }

    #[test]
    fn it_finds_the_last_hardened_index() {
        let cases: [(DerivationPath, (usize, Option<u32>)); 4] = [
            (DerivationPath(vec![]), (0, None)),
            (DerivationPath(vec![1, 2]), (0, None)),
            (
                DerivationPath(vec![1, BIP32_HARDEN, 2]),
                (1, Some(BIP32_HARDEN)),
            ),
            (
                DerivationPath(vec![BIP32_HARDEN, BIP32_HARDEN + 1, 3]),
                (1, Some(BIP32_HARDEN + 1)),
            ),
        ];
        for (path, expected) in cases.iter() {
            assert_eq!(path.last_hardened(), *expected);
        }
    }
}
409}