fedimint_core/util/
mod.rs

1pub mod backoff_util;
2/// Copied from `tokio_stream` 0.1.12 to use our optional Send bounds
3pub mod broadcaststream;
4pub mod update_merge;
5
6use std::convert::Infallible;
7use std::fmt::{Debug, Display, Formatter};
8use std::future::Future;
9use std::hash::Hash;
10use std::io::Write;
11use std::path::Path;
12use std::pin::Pin;
13use std::str::FromStr;
14use std::{fs, io};
15
16use anyhow::format_err;
17use fedimint_logging::LOG_CORE;
18use futures::StreamExt;
19use serde::Serialize;
20use thiserror::Error;
21use tokio::io::AsyncWriteExt;
22use tracing::{debug, warn, Instrument, Span};
23use url::{Host, ParseError, Url};
24
25use crate::net::STANDARD_FEDIMINT_P2P_PORT;
26use crate::task::MaybeSend;
27use crate::{apply, async_trait_maybe_send, maybe_add_send, runtime};
28
29/// Future that is `Send` unless targeting WASM
30pub type BoxFuture<'a, T> = Pin<Box<maybe_add_send!(dyn Future<Output = T> + 'a)>>;
31
32/// Stream that is `Send` unless targeting WASM
33pub type BoxStream<'a, T> = Pin<Box<maybe_add_send!(dyn futures::Stream<Item = T> + 'a)>>;
34
35#[apply(async_trait_maybe_send!)]
36pub trait NextOrPending {
37    type Output;
38
39    async fn next_or_pending(&mut self) -> Self::Output;
40
41    async fn ok(&mut self) -> anyhow::Result<Self::Output>;
42}
43
44#[apply(async_trait_maybe_send!)]
45impl<S> NextOrPending for S
46where
47    S: futures::Stream + Unpin + MaybeSend,
48    S::Item: MaybeSend,
49{
50    type Output = S::Item;
51
52    /// Waits for the next item in a stream. If the stream is closed while
53    /// waiting, returns an error.  Useful when expecting a stream to progress.
54    async fn ok(&mut self) -> anyhow::Result<Self::Output> {
55        self.next()
56            .await
57            .map_or_else(|| Err(format_err!("Stream was unexpectedly closed")), Ok)
58    }
59
60    /// Waits for the next item in a stream. If the stream is closed while
61    /// waiting the future will be pending forever. This is useful in cases
62    /// where the future will be cancelled by shutdown logic anyway and handling
63    /// each place where a stream may terminate would be too much trouble.
64    async fn next_or_pending(&mut self) -> Self::Output {
65        if let Some(item) = self.next().await {
66            item
67        } else {
68            debug!("Stream ended in next_or_pending, pending forever to avoid throwing an error on shutdown");
69            std::future::pending().await
70        }
71    }
72}
73
74// TODO: make fully RFC1738 conformant
75/// Wrapper for `Url` that only prints the scheme, domain, port and path portion
76/// of a `Url` in its `Display` implementation.
77///
78/// This is useful to hide private
79/// information like user names and passwords in logs or UIs.
80///
81/// The output is not fully RFC1738 conformant but good enough for our current
82/// purposes.
83#[derive(Hash, Clone, Serialize, Eq, PartialEq, Ord, PartialOrd)]
84// nosemgrep: ban-raw-url
85pub struct SafeUrl(Url);
86
87#[derive(Debug, Error)]
88pub enum SafeUrlError {
89    #[error("Failed to remove auth from URL")]
90    WithoutAuthError,
91}
92
93impl SafeUrl {
94    pub fn parse(url_str: &str) -> Result<Self, ParseError> {
95        let s = Url::parse(url_str).map(SafeUrl)?;
96
97        if s.port_or_known_default().is_none() {
98            return Err(ParseError::InvalidPort);
99        }
100        Ok(s)
101    }
102
103    /// Warning: This removes the safety.
104    // nosemgrep: ban-raw-url
105    pub fn to_unsafe(self) -> Url {
106        self.0
107    }
108
109    #[allow(clippy::result_unit_err)] // just copying `url`'s API here
110    pub fn set_username(&mut self, username: &str) -> Result<(), ()> {
111        self.0.set_username(username)
112    }
113
114    #[allow(clippy::result_unit_err)] // just copying `url`'s API here
115    pub fn set_password(&mut self, password: Option<&str>) -> Result<(), ()> {
116        self.0.set_password(password)
117    }
118
119    pub fn without_auth(&self) -> Result<Self, SafeUrlError> {
120        let mut url = self.clone();
121        url.set_username("")
122            .and_then(|()| url.set_password(None))
123            .map_err(|()| SafeUrlError::WithoutAuthError)?;
124        Ok(url)
125    }
126
127    pub fn host(&self) -> Option<Host<&str>> {
128        self.0.host()
129    }
130    pub fn host_str(&self) -> Option<&str> {
131        self.0.host_str()
132    }
133    pub fn scheme(&self) -> &str {
134        self.0.scheme()
135    }
136    pub fn port(&self) -> Option<u16> {
137        self.0.port()
138    }
139    pub fn port_or_known_default(&self) -> Option<u16> {
140        if let Some(port) = self.port() {
141            return Some(port);
142        }
143        match self.0.scheme() {
144            // p2p port scheme
145            "fedimint" => Some(STANDARD_FEDIMINT_P2P_PORT),
146            _ => self.0.port_or_known_default(),
147        }
148    }
149
150    /// `self` but with port explicitly set, if known from url
151    pub fn with_port_or_known_default(&self) -> Self {
152        if self.port().is_none() {
153            if let Some(default) = self.port_or_known_default() {
154                let mut url = self.clone();
155                url.0.set_port(Some(default)).expect("Can't fail");
156                return url;
157            }
158        }
159
160        self.clone()
161    }
162
163    pub fn path(&self) -> &str {
164        self.0.path()
165    }
166    /// Warning: This will expose username & password if present.
167    pub fn as_str(&self) -> &str {
168        self.0.as_str()
169    }
170    pub fn username(&self) -> &str {
171        self.0.username()
172    }
173    pub fn password(&self) -> Option<&str> {
174        self.0.password()
175    }
176    pub fn join(&self, input: &str) -> Result<Self, ParseError> {
177        self.0.join(input).map(SafeUrl)
178    }
179
180    // It can be removed to use `is_onion_address()` implementation,
181    // once https://gitlab.torproject.org/tpo/core/arti/-/merge_requests/2214 lands.
182    #[allow(clippy::case_sensitive_file_extension_comparisons)]
183    pub fn is_onion_address(&self) -> bool {
184        let host = self.host_str().unwrap_or_default();
185
186        host.ends_with(".onion")
187    }
188}
189
190impl Display for SafeUrl {
191    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
192        write!(f, "{}://", self.0.scheme())?;
193
194        if !self.0.username().is_empty() {
195            write!(f, "REDACTEDUSER")?;
196
197            if self.0.password().is_some() {
198                write!(f, ":REDACTEDPASS")?;
199            }
200
201            write!(f, "@")?;
202        }
203
204        if let Some(host) = self.0.host_str() {
205            write!(f, "{host}")?;
206        }
207
208        if let Some(port) = self.0.port() {
209            write!(f, ":{port}")?;
210        }
211
212        write!(f, "{}", self.0.path())?;
213
214        Ok(())
215    }
216}
217
218impl<'de> serde::de::Deserialize<'de> for SafeUrl {
219    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
220    where
221        D: serde::Deserializer<'de>,
222    {
223        let s = Self(Url::deserialize(deserializer)?);
224
225        if s.port_or_known_default().is_none() {
226            return Err(serde::de::Error::custom("Invalid port"));
227        }
228
229        Ok(s)
230    }
231}
232impl Debug for SafeUrl {
233    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
234        write!(f, "SafeUrl(")?;
235        Display::fmt(self, f)?;
236        write!(f, ")")?;
237        Ok(())
238    }
239}
240
241/// Only ease conversions from unsafe into safe version.
242/// We want to protect leakage of sensitive credentials unless code explicitly
243/// calls `to_unsafe()`.
244impl TryFrom<Url> for SafeUrl {
245    type Error = anyhow::Error;
246    fn try_from(u: Url) -> anyhow::Result<Self> {
247        let s = Self(u);
248
249        if s.port_or_known_default().is_none() {
250            anyhow::bail!("Invalid port");
251        }
252
253        Ok(s)
254    }
255}
256
257impl FromStr for SafeUrl {
258    type Err = ParseError;
259
260    #[inline]
261    fn from_str(input: &str) -> Result<Self, ParseError> {
262        Self::parse(input)
263    }
264}
265
/// Writes `contents` into a freshly created file at `path`.
///
/// Works like [`std::fs::write`], except that an already existing file is
/// never touched. The open fails with [`io::ErrorKind::AlreadyExists`] instead.
#[cfg(not(target_family = "wasm"))]
pub fn write_new<P: AsRef<Path>, C: AsRef<[u8]>>(path: P, contents: C) -> io::Result<()> {
    let mut file = fs::OpenOptions::new()
        .write(true)
        .create_new(true)
        .open(path)?;
    file.write_all(contents.as_ref())
}
276
/// Writes `contents` to `path`, creating the file if it is missing and
/// truncating any previous content first.
#[cfg(not(target_family = "wasm"))]
pub fn write_overwrite<P: AsRef<Path>, C: AsRef<[u8]>>(path: P, contents: C) -> io::Result<()> {
    let mut options = fs::OpenOptions::new();
    options.write(true).create(true).truncate(true);

    let mut file = options.open(path)?;
    file.write_all(contents.as_ref())
}
286
287#[cfg(not(target_family = "wasm"))]
288pub async fn write_overwrite_async<P: AsRef<Path>, C: AsRef<[u8]>>(
289    path: P,
290    contents: C,
291) -> io::Result<()> {
292    tokio::fs::OpenOptions::new()
293        .write(true)
294        .create(true)
295        .truncate(true)
296        .open(path)
297        .await?
298        .write_all(contents.as_ref())
299        .await
300}
301
302#[cfg(not(target_family = "wasm"))]
303pub async fn write_new_async<P: AsRef<Path>, C: AsRef<[u8]>>(
304    path: P,
305    contents: C,
306) -> io::Result<()> {
307    tokio::fs::OpenOptions::new()
308        .write(true)
309        .create_new(true)
310        .open(path)
311        .await?
312        .write_all(contents.as_ref())
313        .await
314}
315
316#[derive(Debug, Clone)]
317pub struct Spanned<T> {
318    value: T,
319    span: Span,
320}
321
322impl<T> Spanned<T> {
323    pub async fn new<F: Future<Output = T>>(span: Span, make: F) -> Self {
324        Self::try_new::<Infallible, _>(span, async { Ok(make.await) })
325            .await
326            .unwrap()
327    }
328
329    pub async fn try_new<E, F: Future<Output = Result<T, E>>>(
330        span: Span,
331        make: F,
332    ) -> Result<Self, E> {
333        let span2 = span.clone();
334        async {
335            Ok(Self {
336                value: make.await?,
337                span: span2,
338            })
339        }
340        .instrument(span)
341        .await
342    }
343
344    pub fn borrow(&self) -> Spanned<&T> {
345        Spanned {
346            value: &self.value,
347            span: self.span.clone(),
348        }
349    }
350
351    pub fn map<U>(self, map: impl Fn(T) -> U) -> Spanned<U> {
352        Spanned {
353            value: map(self.value),
354            span: self.span,
355        }
356    }
357
358    pub fn borrow_mut(&mut self) -> Spanned<&mut T> {
359        Spanned {
360            value: &mut self.value,
361            span: self.span.clone(),
362        }
363    }
364
365    pub fn with_sync<O, F: FnOnce(T) -> O>(self, f: F) -> O {
366        let _g = self.span.enter();
367        f(self.value)
368    }
369
370    pub async fn with<Fut: Future, F: FnOnce(T) -> Fut>(self, f: F) -> Fut::Output {
371        async { f(self.value).await }.instrument(self.span).await
372    }
373
374    pub fn span(&self) -> Span {
375        self.span.clone()
376    }
377
378    pub fn value(&self) -> &T {
379        &self.value
380    }
381
382    pub fn value_mut(&mut self) -> &mut T {
383        &mut self.value
384    }
385
386    pub fn into_value(self) -> T {
387        self.value
388    }
389}
390
/// For CLIs: if the first command-line argument is `version-hash`, prints
/// `version_hash` to stdout and exits the process with status 0.
///
/// Any other (or missing) first argument makes this a no-op. Arguments after
/// the first are not inspected.
pub fn handle_version_hash_command(version_hash: &str) {
    let first_arg = std::env::args().nth(1);
    if first_arg.as_deref() != Some("version-hash") {
        return;
    }
    println!("{version_hash}");
    std::process::exit(0);
}
402
403/// Run the supplied closure `op_fn` until it succeeds. Frequency and number of
404/// retries is determined by the specified strategy.
405///
406/// ```
407/// use std::time::Duration;
408///
409/// use fedimint_core::util::{backoff_util, retry};
410/// # tokio_test::block_on(async {
411/// retry(
412///     "Gateway balance after swap".to_string(),
413///     backoff_util::background_backoff(),
414///     || async {
415///         // Fallible network calls …
416///         Ok(())
417///     },
418/// )
419/// .await
420/// .expect("never fails");
421/// # });
422/// ```
423///
424/// # Returns
425///
426/// - If the closure runs successfully, the result is immediately returned
427/// - If the closure did not run successfully for `max_attempts` times, the
428///   error of the closure is returned
429pub async fn retry<F, Fut, T>(
430    op_name: impl Into<String>,
431    strategy: impl backoff_util::Backoff,
432    op_fn: F,
433) -> Result<T, anyhow::Error>
434where
435    F: Fn() -> Fut,
436    Fut: Future<Output = Result<T, anyhow::Error>>,
437{
438    let mut strategy = strategy;
439    let op_name = op_name.into();
440    let mut attempts: u64 = 0;
441    loop {
442        attempts += 1;
443        match op_fn().await {
444            Ok(result) => return Ok(result),
445            Err(error) => {
446                if let Some(interval) = strategy.next() {
447                    // run closure op_fn again
448                    debug!(
449                        target: LOG_CORE,
450                        %error,
451                        %attempts,
452                        interval = interval.as_secs(),
453                        "{} failed, retrying",
454                        op_name,
455                    );
456                    runtime::sleep(interval).await;
457                } else {
458                    warn!(
459                        target: LOG_CORE,
460                        ?error,
461                        %attempts,
462                        "{} failed",
463                        op_name,
464                    );
465                    return Err(error);
466                }
467            }
468        }
469    }
470}
471
#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicU8, Ordering};
    use std::time::Duration;

    use anyhow::anyhow;
    use fedimint_core::runtime::Elapsed;
    use futures::FutureExt;

    use super::*;
    use crate::runtime::timeout;

    #[test]
    fn test_safe_url() {
        let test_cases = vec![
            (
                "http://1.2.3.4:80/foo",
                "http://1.2.3.4/foo",
                "SafeUrl(http://1.2.3.4/foo)",
                "http://1.2.3.4/foo",
            ),
            (
                "http://1.2.3.4:81/foo",
                "http://1.2.3.4:81/foo",
                "SafeUrl(http://1.2.3.4:81/foo)",
                "http://1.2.3.4:81/foo",
            ),
            (
                "fedimint://1.2.3.4:1000/foo",
                "fedimint://1.2.3.4:1000/foo",
                "SafeUrl(fedimint://1.2.3.4:1000/foo)",
                "fedimint://1.2.3.4:1000/foo",
            ),
            (
                "fedimint://foo:bar@domain.com:1000/foo",
                "fedimint://REDACTEDUSER:REDACTEDPASS@domain.com:1000/foo",
                "SafeUrl(fedimint://REDACTEDUSER:REDACTEDPASS@domain.com:1000/foo)",
                "fedimint://domain.com:1000/foo",
            ),
            (
                "fedimint://foo@1.2.3.4:1000/foo",
                "fedimint://REDACTEDUSER@1.2.3.4:1000/foo",
                "SafeUrl(fedimint://REDACTEDUSER@1.2.3.4:1000/foo)",
                "fedimint://1.2.3.4:1000/foo",
            ),
        ];

        for (url_str, safe_display_expected, safe_debug_expected, without_auth_expected) in
            test_cases
        {
            let safe_url = SafeUrl::parse(url_str).unwrap();

            let safe_display = format!("{safe_url}");
            assert_eq!(
                safe_display, safe_display_expected,
                "Display implementation out of spec"
            );

            let safe_debug = format!("{safe_url:?}");
            assert_eq!(
                safe_debug, safe_debug_expected,
                "Debug implementation out of spec"
            );

            let without_auth = safe_url.without_auth().unwrap();
            assert_eq!(
                without_auth.as_str(),
                without_auth_expected,
                "Without auth implementation out of spec"
            );
        }

        // Exercise `TryFrom`-trait via `TryInto`
        let _: SafeUrl = url::Url::parse("http://1.2.3.4:80/foo")
            .unwrap()
            .try_into()
            .unwrap();
    }

    #[tokio::test]
    async fn test_next_or_pending() {
        let mut stream = futures::stream::iter(vec![1, 2]);
        assert_eq!(stream.next_or_pending().now_or_never(), Some(1));
        assert_eq!(stream.next_or_pending().now_or_never(), Some(2));
        assert!(matches!(
            timeout(Duration::from_millis(100), stream.next_or_pending()).await,
            Err(Elapsed)
        ));
    }

    #[tokio::test]
    async fn retry_succeed_with_one_attempt() {
        let counter = AtomicU8::new(0);
        let closure = || async {
            counter.fetch_add(1, Ordering::SeqCst);
            // Always return a success.
            Ok(42)
        };

        let result = retry(
            "Run once",
            backoff_util::immediate_backoff(Some(2)),
            closure,
        )
        .await;

        // The successful value is passed through unchanged.
        assert_eq!(result.unwrap(), 42);
        // Ensure the closure was only called once, and no backoff was applied.
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn retry_fail_with_three_attempts() {
        let counter = AtomicU8::new(0);
        let closure = || async {
            counter.fetch_add(1, Ordering::SeqCst);
            // always fail
            Err::<(), anyhow::Error>(anyhow!("42"))
        };

        let result = retry(
            "Run 3 times",
            backoff_util::immediate_backoff(Some(2)),
            closure,
        )
        .await;

        // Once the strategy is exhausted, the last attempt's error is returned.
        assert_eq!(result.unwrap_err().to_string(), "42");
        // One initial attempt plus two retries.
        assert_eq!(counter.load(Ordering::SeqCst), 3);
    }
}