coins_bip32/
path.rs

1use std::{
2    convert::TryFrom,
3    io::{Read, Write},
4    iter::{FromIterator, IntoIterator},
5    slice::Iter,
6    str::FromStr,
7};
8
9use coins_core::ser::ByteFormat;
10
11use crate::{primitives::KeyFingerprint, Bip32Error, BIP32_HARDEN};
12
13fn try_parse_index(s: &str) -> Result<u32, Bip32Error> {
14    let mut index_str = s.to_owned();
15    let harden = if s.ends_with('\'') || s.ends_with('h') {
16        index_str.pop();
17        true
18    } else {
19        false
20    };
21
22    index_str
23        .parse::<u32>()
24        .map(|v| if harden { harden_index(v) } else { v })
25        .map_err(|_| Bip32Error::MalformattedDerivation(s.to_owned()))
26}
27
28fn encode_index(idx: u32, harden: char) -> String {
29    let mut s = (idx % BIP32_HARDEN).to_string();
30    if idx >= BIP32_HARDEN {
31        s.push(harden);
32    }
33    s
34}
35
36/// Converts an raw index to hardened
37pub const fn harden_index(index: u32) -> u32 {
38    index + BIP32_HARDEN
39}
40
/// A BIP-32 derivation path: the ordered list of raw child indices, from the
/// root downwards. Hardened steps are stored with the `BIP32_HARDEN` offset
/// already applied.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct DerivationPath(Vec<u32>);
44
45impl serde::Serialize for DerivationPath {
46    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
47    where
48        S: serde::Serializer,
49    {
50        serializer.serialize_str(&self.derivation_string())
51    }
52}
53
54impl<'de> serde::Deserialize<'de> for DerivationPath {
55    fn deserialize<D>(deserializer: D) -> Result<DerivationPath, D::Error>
56    where
57        D: serde::Deserializer<'de>,
58    {
59        let s: &str = serde::Deserialize::deserialize(deserializer)?;
60        s.parse::<DerivationPath>()
61            .map_err(|e| serde::de::Error::custom(e.to_string()))
62    }
63}
64
65impl DerivationPath {
66    #[doc(hidden)]
67    pub fn custom_string(&self, root: &str, joiner: char, harden: char) -> String {
68        std::iter::once(root.to_owned())
69            .chain(self.0.iter().map(|s| encode_index(*s, harden)))
70            .collect::<Vec<String>>()
71            .join(&joiner.to_string())
72    }
73
74    /// Return the last index in the path. None if the path is the root.
75    pub fn last(&self) -> Option<&u32> {
76        self.0.last()
77    }
78
79    /// Converts the path to a standard bip32 string. e.g `"m/44'/0'/0/32"`.
80    pub fn derivation_string(&self) -> String {
81        self.custom_string("m", '/', '\'')
82    }
83
84    /// Returns `True` if there are no indices in the path
85    pub fn is_empty(&self) -> bool {
86        self.0.is_empty()
87    }
88
89    /// The number of derivations in the path
90    pub fn len(&self) -> usize {
91        self.0.len()
92    }
93
94    /// Make an iterator over the path indices
95    pub fn iter(&self) -> Iter<u32> {
96        self.0.iter()
97    }
98
99    /// `true` if `other` is a prefix of `self`
100    pub fn starts_with(&self, other: &Self) -> bool {
101        self.0.starts_with(&other.0)
102    }
103
104    /// Remove a prefix from a derivation. Return a new DerivationPath without the prefix.
105    /// This is useful for determining the path to rech some descendant from some ancestor.
106    pub fn without_prefix(&self, prefix: &Self) -> Option<DerivationPath> {
107        if !self.starts_with(prefix) {
108            None
109        } else {
110            Some(self.0[prefix.len()..].to_vec().into())
111        }
112    }
113
114    /// Convenience function for finding the last hardened derivation in a path.
115    /// Returns the index and the element. If there is no hardened derivation, it
116    /// will return (0, None).
117    pub fn last_hardened(&self) -> (usize, Option<u32>) {
118        match self.iter().rev().position(|v| *v >= BIP32_HARDEN) {
119            Some(rev_pos) => {
120                let pos = self.len() - rev_pos - 1;
121                (pos, Some(self.0[pos]))
122            }
123            None => (0, None),
124        }
125    }
126
127    /// Return a clone with a resized path. If the new size is shorter, this truncates it. If the
128    /// new path is longer, we pad with the second argument.
129    pub fn resized(&self, size: usize, pad_with: u32) -> Self {
130        let mut child = self.clone();
131        child.0.resize(size, pad_with);
132        child
133    }
134
135    /// Append an additional derivation to the end, return a clone
136    pub fn extended(&self, idx: u32) -> Self {
137        let mut child = self.clone();
138        child.0.push(idx);
139        child
140    }
141}
142
143impl From<&DerivationPath> for DerivationPath {
144    fn from(v: &DerivationPath) -> Self {
145        v.clone()
146    }
147}
148
149impl From<Vec<u32>> for DerivationPath {
150    fn from(v: Vec<u32>) -> Self {
151        Self(v)
152    }
153}
154
155impl From<&Vec<u32>> for DerivationPath {
156    fn from(v: &Vec<u32>) -> Self {
157        Self(v.clone())
158    }
159}
160
161impl From<&[u32]> for DerivationPath {
162    fn from(v: &[u32]) -> Self {
163        Self(Vec::from(v))
164    }
165}
166
167impl TryFrom<u32> for DerivationPath {
168    type Error = Bip32Error;
169
170    fn try_from(v: u32) -> Result<Self, Self::Error> {
171        Ok(Self(vec![v]))
172    }
173}
174
175impl TryFrom<&str> for DerivationPath {
176    type Error = Bip32Error;
177
178    fn try_from(v: &str) -> Result<Self, Self::Error> {
179        v.parse()
180    }
181}
182
183impl FromIterator<u32> for DerivationPath {
184    fn from_iter<T>(iter: T) -> Self
185    where
186        T: IntoIterator<Item = u32>,
187    {
188        Vec::from_iter(iter).into()
189    }
190}
191
192impl FromStr for DerivationPath {
193    type Err = Bip32Error;
194
195    fn from_str(s: &str) -> Result<Self, Self::Err> {
196        s.split('/')
197            .filter(|v| v != &"m")
198            .map(try_parse_index)
199            .collect::<Result<Vec<u32>, Bip32Error>>()
200            .map(|v| v.into())
201            .map_err(|_| Bip32Error::MalformattedDerivation(s.to_owned()))
202    }
203}
204
205/// A Derivation Path for a bip32 key
206#[derive(Debug, Clone, Eq, PartialEq, serde::Serialize, serde::Deserialize)]
207pub struct KeyDerivation {
208    /// The root key fingerprint
209    pub root: KeyFingerprint,
210    /// The derivation path from the root key
211    pub path: DerivationPath,
212}
213
214impl KeyDerivation {
215    /// `true` if the keys share a root fingerprint, `false` otherwise. Note that on key
216    /// fingerprints, which may collide accidentally, or be intentionally collided.
217    pub fn same_root(&self, other: &Self) -> bool {
218        self.root == other.root
219    }
220
221    /// `true` if this key is an ancestor of other, `false` otherwise. Note that on key
222    /// fingerprints, which may collide accidentally, or be intentionally collided.
223    pub fn is_possible_ancestor_of(&self, other: &Self) -> bool {
224        self.same_root(other) && other.path.starts_with(&self.path)
225    }
226
227    /// Returns the path to the decendant.
228    pub fn path_to_descendant(&self, descendant: &Self) -> Option<DerivationPath> {
229        descendant.path.without_prefix(&self.path)
230    }
231
232    /// Return a clone with a resized path. If the new size is shorter, this truncates it. If the
233    /// new path is longer, we pad with the second argument.
234    pub fn resized(&self, size: usize, pad_with: u32) -> Self {
235        Self {
236            root: self.root,
237            path: self.path.resized(size, pad_with),
238        }
239    }
240
241    /// Append an additional derivation to the end, return a clone
242    pub fn extended(&self, idx: u32) -> Self {
243        Self {
244            root: self.root,
245            path: self.path.extended(idx),
246        }
247    }
248}
249
250impl ByteFormat for KeyDerivation {
251    type Error = Bip32Error;
252
253    fn serialized_length(&self) -> usize {
254        4 + 4 * self.path.len()
255    }
256
257    fn read_from<T>(_reader: &mut T) -> Result<Self, Self::Error>
258    where
259        T: Read,
260        Self: std::marker::Sized,
261    {
262        unimplemented!()
263        // if limit == 0 {
264        //     return Err(SerError::RequiresLimit.into());
265        // }
266
267        // if limit > 255 {
268        //     return Err(Bip32Error::InvalidBip32Path);
269        // }
270
271        // let mut finger = [0u8; 4];
272        // reader.read_exact(&mut finger)?;
273
274        // let mut path = vec![];
275        // for _ in 0..limit {
276        //     let mut buf = [0u8; 4];
277        //     reader.read_exact(&mut buf)?;
278        //     path.push(u32::from_le_bytes(buf));
279        // }
280
281        // Ok(KeyDerivation {
282        //     root: finger.into(),
283        //     path: path.into(),
284        // })
285    }
286
287    fn write_to<T>(&self, writer: &mut T) -> Result<usize, Self::Error>
288    where
289        T: Write,
290    {
291        let mut length = writer.write(&self.root.0)?;
292        for i in self.path.iter() {
293            length += writer.write(&i.to_le_bytes())?;
294        }
295        Ok(length)
296    }
297}
298
#[cfg(test)]
pub mod test {
    use super::*;

    #[test]
    fn it_parses_index_strings() {
        let cases = [("32", 32), ("32h", 32 + BIP32_HARDEN), ("0h", BIP32_HARDEN)];
        for case in cases.iter() {
            match try_parse_index(case.0) {
                Ok(v) => assert_eq!(v, case.1),
                Err(e) => panic!("unexpected error {}", e),
            }
        }
    }

    #[test]
    fn it_handles_malformatted_indices() {
        let cases = ["-", "h", "toast", "憂鬱"];
        for case in cases.iter() {
            match try_parse_index(case) {
                Ok(_) => panic!("expected an error"),
                Err(Bip32Error::MalformattedDerivation(e)) => assert_eq!(&e, case),
                Err(e) => panic!("unexpected error {}", e),
            }
        }
    }

    #[test]
    fn it_parses_derivation_strings() {
        let cases = [
            ("m/32", vec![32]),
            ("m/32'", vec![32 + BIP32_HARDEN]),
            ("m/0'/32/5/5/5", vec![BIP32_HARDEN, 32, 5, 5, 5]),
            ("32", vec![32]),
            ("32'", vec![32 + BIP32_HARDEN]),
            ("0'/32/5/5/5", vec![BIP32_HARDEN, 32, 5, 5, 5]),
        ];
        for case in cases.iter() {
            match case.0.parse::<DerivationPath>() {
                Ok(v) => assert_eq!(v.0, case.1),
                Err(e) => panic!("unexpected error {}", e),
            }
        }
    }

    #[test]
    fn it_handles_malformatted_derivations() {
        let cases = ["//", "m/", "-", "h", "toast", "憂鬱"];
        for case in cases.iter() {
            match case.parse::<DerivationPath>() {
                Ok(_) => panic!("expected an error"),
                Err(Bip32Error::MalformattedDerivation(e)) => assert_eq!(&e, case),
                Err(e) => panic!("unexpected error {}", e),
            }
        }
    }

    #[test]
    fn it_removes_prefixes_from_derivations() {
        // express each row in a separate instantiation syntax :)
        let cases = [
            (
                DerivationPath(vec![1, 2, 3]),
                DerivationPath(vec![1]),
                Some(DerivationPath(vec![2, 3])),
            ),
            (
                vec![1, 2, 3].into(),
                vec![1, 2].into(),
                Some(vec![3].into()),
            ),
            (
                (1u32..=3).collect(),
                (1u32..=3).collect(),
                Some((0..0).collect()),
            ),
            (DerivationPath(vec![1, 2, 3]), vec![1, 3].into(), None),
        ];
        for case in cases.iter() {
            assert_eq!(case.0.without_prefix(&case.1), case.2);
        }
    }

    #[test]
    fn it_produces_paths_from_strings() {
        let cases = ["//", "m/", "-", "h", "toast", "憂鬱"];

        for case in cases.iter() {
            let path: Result<DerivationPath, _> = case.parse();
            match path {
                Ok(_) => panic!("expected an error"),
                Err(Bip32Error::MalformattedDerivation(e)) => assert_eq!(&e, case),
                Err(e) => panic!("unexpected error {}", e),
            }
        }
    }

    #[test]
    fn it_stringifies_derivation_paths() {
        let cases = [
            (DerivationPath(vec![1, 2, 3]), "m/1/2/3"),
            (
                vec![BIP32_HARDEN, BIP32_HARDEN, BIP32_HARDEN].into(),
                "m/0'/0'/0'",
            ),
        ];
        for case in cases.iter() {
            assert_eq!(&case.0.derivation_string(), case.1);
        }
    }

    #[test]
    fn it_round_trips_derivation_strings() {
        for &s in ["m", "m/0", "m/44'/0'/0'/0/1", "m/2147483647'"].iter() {
            let path: DerivationPath = s.parse().unwrap();
            assert_eq!(path.derivation_string(), s);
        }
        // `h` is accepted as a hardened marker but always rendered as `'`.
        let path: DerivationPath = "m/44h/0h".parse().unwrap();
        assert_eq!(path.derivation_string(), "m/44'/0'");
    }

    #[test]
    fn it_finds_last_hardened() {
        let path: DerivationPath = "m/44'/0'/0/1".parse().unwrap();
        assert_eq!(path.last_hardened(), (1, Some(BIP32_HARDEN)));
        assert_eq!(DerivationPath::default().last_hardened(), (0, None));
    }
}