fedimint_core/util/
mod.rs

1pub mod backoff_util;
2/// Copied from `tokio_stream` 0.1.12 to use our optional Send bounds
3pub mod broadcaststream;
4mod error;
5pub mod update_merge;
6
7use std::convert::Infallible;
8use std::fmt::{Debug, Display, Formatter};
9use std::future::Future;
10use std::hash::Hash;
11use std::io::Write;
12use std::path::Path;
13use std::pin::Pin;
14use std::str::FromStr;
15use std::{fs, io};
16
17use anyhow::format_err;
18pub use error::*;
19use fedimint_logging::LOG_CORE;
20use futures::StreamExt;
21use serde::Serialize;
22use thiserror::Error;
23use tokio::io::AsyncWriteExt;
24use tracing::{debug, warn, Instrument, Span};
25use url::{Host, ParseError, Url};
26
27use crate::net::STANDARD_FEDIMINT_P2P_PORT;
28use crate::task::MaybeSend;
29use crate::{apply, async_trait_maybe_send, maybe_add_send, runtime};
30
/// Boxed, pinned, type-erased future that resolves to `T` and may borrow for `'a`.
///
/// The trait object is `Send` unless targeting WASM; `maybe_add_send!` supplies
/// that bound.
pub type BoxFuture<'a, T> = Pin<Box<maybe_add_send!(dyn Future<Output = T> + 'a)>>;
33
/// Boxed, pinned, type-erased stream of `T` items that may borrow for `'a`.
///
/// The trait object is `Send` unless targeting WASM; `maybe_add_send!` supplies
/// that bound.
pub type BoxStream<'a, T> = Pin<Box<maybe_add_send!(dyn futures::Stream<Item = T> + 'a)>>;
36
37#[apply(async_trait_maybe_send!)]
38pub trait NextOrPending {
39    type Output;
40
41    async fn next_or_pending(&mut self) -> Self::Output;
42
43    async fn ok(&mut self) -> anyhow::Result<Self::Output>;
44}
45
46#[apply(async_trait_maybe_send!)]
47impl<S> NextOrPending for S
48where
49    S: futures::Stream + Unpin + MaybeSend,
50    S::Item: MaybeSend,
51{
52    type Output = S::Item;
53
54    /// Waits for the next item in a stream. If the stream is closed while
55    /// waiting, returns an error.  Useful when expecting a stream to progress.
56    async fn ok(&mut self) -> anyhow::Result<Self::Output> {
57        self.next()
58            .await
59            .map_or_else(|| Err(format_err!("Stream was unexpectedly closed")), Ok)
60    }
61
62    /// Waits for the next item in a stream. If the stream is closed while
63    /// waiting the future will be pending forever. This is useful in cases
64    /// where the future will be cancelled by shutdown logic anyway and handling
65    /// each place where a stream may terminate would be too much trouble.
66    async fn next_or_pending(&mut self) -> Self::Output {
67        if let Some(item) = self.next().await {
68            item
69        } else {
70            debug!(target: LOG_CORE, "Stream ended in next_or_pending, pending forever to avoid throwing an error on shutdown");
71            std::future::pending().await
72        }
73    }
74}
75
76// TODO: make fully RFC1738 conformant
77/// Wrapper for `Url` that only prints the scheme, domain, port and path portion
78/// of a `Url` in its `Display` implementation.
79///
80/// This is useful to hide private
81/// information like user names and passwords in logs or UIs.
82///
83/// The output is not fully RFC1738 conformant but good enough for our current
84/// purposes.
85#[derive(Hash, Clone, Serialize, Eq, PartialEq, Ord, PartialOrd)]
86// nosemgrep: ban-raw-url
87pub struct SafeUrl(Url);
88
89#[derive(Debug, Error)]
90pub enum SafeUrlError {
91    #[error("Failed to remove auth from URL")]
92    WithoutAuthError,
93}
94
95impl SafeUrl {
96    pub fn parse(url_str: &str) -> Result<Self, ParseError> {
97        let s = Url::parse(url_str).map(SafeUrl)?;
98
99        if s.port_or_known_default().is_none() {
100            return Err(ParseError::InvalidPort);
101        }
102        Ok(s)
103    }
104
105    /// Warning: This removes the safety.
106    // nosemgrep: ban-raw-url
107    pub fn to_unsafe(self) -> Url {
108        self.0
109    }
110
111    #[allow(clippy::result_unit_err)] // just copying `url`'s API here
112    pub fn set_username(&mut self, username: &str) -> Result<(), ()> {
113        self.0.set_username(username)
114    }
115
116    #[allow(clippy::result_unit_err)] // just copying `url`'s API here
117    pub fn set_password(&mut self, password: Option<&str>) -> Result<(), ()> {
118        self.0.set_password(password)
119    }
120
121    pub fn without_auth(&self) -> Result<Self, SafeUrlError> {
122        let mut url = self.clone();
123        url.set_username("")
124            .and_then(|()| url.set_password(None))
125            .map_err(|()| SafeUrlError::WithoutAuthError)?;
126        Ok(url)
127    }
128
129    pub fn host(&self) -> Option<Host<&str>> {
130        self.0.host()
131    }
132    pub fn host_str(&self) -> Option<&str> {
133        self.0.host_str()
134    }
135    pub fn scheme(&self) -> &str {
136        self.0.scheme()
137    }
138    pub fn port(&self) -> Option<u16> {
139        self.0.port()
140    }
141    pub fn port_or_known_default(&self) -> Option<u16> {
142        if let Some(port) = self.port() {
143            return Some(port);
144        }
145        match self.0.scheme() {
146            // p2p port scheme
147            "fedimint" => Some(STANDARD_FEDIMINT_P2P_PORT),
148            _ => self.0.port_or_known_default(),
149        }
150    }
151
152    /// `self` but with port explicitly set, if known from url
153    pub fn with_port_or_known_default(&self) -> Self {
154        if self.port().is_none() {
155            if let Some(default) = self.port_or_known_default() {
156                let mut url = self.clone();
157                url.0.set_port(Some(default)).expect("Can't fail");
158                return url;
159            }
160        }
161
162        self.clone()
163    }
164
165    pub fn path(&self) -> &str {
166        self.0.path()
167    }
168    /// Warning: This will expose username & password if present.
169    pub fn as_str(&self) -> &str {
170        self.0.as_str()
171    }
172    pub fn username(&self) -> &str {
173        self.0.username()
174    }
175    pub fn password(&self) -> Option<&str> {
176        self.0.password()
177    }
178    pub fn join(&self, input: &str) -> Result<Self, ParseError> {
179        self.0.join(input).map(SafeUrl)
180    }
181
182    // It can be removed to use `is_onion_address()` implementation,
183    // once https://gitlab.torproject.org/tpo/core/arti/-/merge_requests/2214 lands.
184    #[allow(clippy::case_sensitive_file_extension_comparisons)]
185    pub fn is_onion_address(&self) -> bool {
186        let host = self.host_str().unwrap_or_default();
187
188        host.ends_with(".onion")
189    }
190
191    pub fn fragment(&self) -> Option<&str> {
192        self.0.fragment()
193    }
194
195    pub fn set_fragment(&mut self, arg: Option<&str>) {
196        self.0.set_fragment(arg);
197    }
198}
199
200impl Display for SafeUrl {
201    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
202        write!(f, "{}://", self.0.scheme())?;
203
204        if !self.0.username().is_empty() {
205            write!(f, "REDACTEDUSER")?;
206
207            if self.0.password().is_some() {
208                write!(f, ":REDACTEDPASS")?;
209            }
210
211            write!(f, "@")?;
212        }
213
214        if let Some(host) = self.0.host_str() {
215            write!(f, "{host}")?;
216        }
217
218        if let Some(port) = self.0.port() {
219            write!(f, ":{port}")?;
220        }
221
222        write!(f, "{}", self.0.path())?;
223
224        Ok(())
225    }
226}
227
228impl<'de> serde::de::Deserialize<'de> for SafeUrl {
229    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
230    where
231        D: serde::Deserializer<'de>,
232    {
233        let s = Self(Url::deserialize(deserializer)?);
234
235        if s.port_or_known_default().is_none() {
236            return Err(serde::de::Error::custom("Invalid port"));
237        }
238
239        Ok(s)
240    }
241}
242impl Debug for SafeUrl {
243    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
244        write!(f, "SafeUrl(")?;
245        Display::fmt(self, f)?;
246        write!(f, ")")?;
247        Ok(())
248    }
249}
250
251/// Only ease conversions from unsafe into safe version.
252/// We want to protect leakage of sensitive credentials unless code explicitly
253/// calls `to_unsafe()`.
254impl TryFrom<Url> for SafeUrl {
255    type Error = anyhow::Error;
256    fn try_from(u: Url) -> anyhow::Result<Self> {
257        let s = Self(u);
258
259        if s.port_or_known_default().is_none() {
260            anyhow::bail!("Invalid port");
261        }
262
263        Ok(s)
264    }
265}
266
267impl FromStr for SafeUrl {
268    type Err = ParseError;
269
270    #[inline]
271    fn from_str(input: &str) -> Result<Self, ParseError> {
272        Self::parse(input)
273    }
274}
275
/// Creates `path` and writes `contents` into it.
///
/// Unlike [`std::fs::write`], an existing file is never touched. If anything
/// already exists at `path`, the call fails (typically with
/// [`io::ErrorKind::AlreadyExists`]) and the existing file is left as it was.
#[cfg(not(target_family = "wasm"))]
pub fn write_new<P: AsRef<Path>, C: AsRef<[u8]>>(path: P, contents: C) -> io::Result<()> {
    let mut file = fs::OpenOptions::new()
        .create_new(true)
        .write(true)
        .open(path)?;
    file.write_all(contents.as_ref())
}
286
/// Writes `contents` to `path`, creating the file if it is missing and
/// truncating it if it exists.
///
/// The open options (write + create + truncate) are the same as
/// [`std::fs::write`] uses.
#[cfg(not(target_family = "wasm"))]
pub fn write_overwrite<P: AsRef<Path>, C: AsRef<[u8]>>(path: P, contents: C) -> io::Result<()> {
    let mut options = fs::OpenOptions::new();
    options.truncate(true).create(true).write(true);

    let mut file = options.open(path)?;
    file.write_all(contents.as_ref())
}
296
297#[cfg(not(target_family = "wasm"))]
298pub async fn write_overwrite_async<P: AsRef<Path>, C: AsRef<[u8]>>(
299    path: P,
300    contents: C,
301) -> io::Result<()> {
302    tokio::fs::OpenOptions::new()
303        .write(true)
304        .create(true)
305        .truncate(true)
306        .open(path)
307        .await?
308        .write_all(contents.as_ref())
309        .await
310}
311
312#[cfg(not(target_family = "wasm"))]
313pub async fn write_new_async<P: AsRef<Path>, C: AsRef<[u8]>>(
314    path: P,
315    contents: C,
316) -> io::Result<()> {
317    tokio::fs::OpenOptions::new()
318        .write(true)
319        .create_new(true)
320        .open(path)
321        .await?
322        .write_all(contents.as_ref())
323        .await
324}
325
326#[derive(Debug, Clone)]
327pub struct Spanned<T> {
328    value: T,
329    span: Span,
330}
331
332impl<T> Spanned<T> {
333    pub async fn new<F: Future<Output = T>>(span: Span, make: F) -> Self {
334        Self::try_new::<Infallible, _>(span, async { Ok(make.await) })
335            .await
336            .unwrap()
337    }
338
339    pub async fn try_new<E, F: Future<Output = Result<T, E>>>(
340        span: Span,
341        make: F,
342    ) -> Result<Self, E> {
343        let span2 = span.clone();
344        async {
345            Ok(Self {
346                value: make.await?,
347                span: span2,
348            })
349        }
350        .instrument(span)
351        .await
352    }
353
354    pub fn borrow(&self) -> Spanned<&T> {
355        Spanned {
356            value: &self.value,
357            span: self.span.clone(),
358        }
359    }
360
361    pub fn map<U>(self, map: impl Fn(T) -> U) -> Spanned<U> {
362        Spanned {
363            value: map(self.value),
364            span: self.span,
365        }
366    }
367
368    pub fn borrow_mut(&mut self) -> Spanned<&mut T> {
369        Spanned {
370            value: &mut self.value,
371            span: self.span.clone(),
372        }
373    }
374
375    pub fn with_sync<O, F: FnOnce(T) -> O>(self, f: F) -> O {
376        let _g = self.span.enter();
377        f(self.value)
378    }
379
380    pub async fn with<Fut: Future, F: FnOnce(T) -> Fut>(self, f: F) -> Fut::Output {
381        async { f(self.value).await }.instrument(self.span).await
382    }
383
384    pub fn span(&self) -> Span {
385        self.span.clone()
386    }
387
388    pub fn value(&self) -> &T {
389        &self.value
390    }
391
392    pub fn value_mut(&mut self) -> &mut T {
393        &mut self.value
394    }
395
396    pub fn into_value(self) -> T {
397        self.value
398    }
399}
400
/// For CLIs: if the first command-line argument is exactly `version-hash`, the
/// function prints `version_hash` to stdout and exits the process with status 0.
/// In every other case it returns without doing anything.
pub fn handle_version_hash_command(version_hash: &str) {
    let first_arg = std::env::args().nth(1);
    if first_arg.as_deref() == Some("version-hash") {
        println!("{version_hash}");
        std::process::exit(0);
    }
}
412
413/// Run the supplied closure `op_fn` until it succeeds. Frequency and number of
414/// retries is determined by the specified strategy.
415///
416/// ```
417/// use std::time::Duration;
418///
419/// use fedimint_core::util::{backoff_util, retry};
420/// # tokio_test::block_on(async {
421/// retry(
422///     "Gateway balance after swap".to_string(),
423///     backoff_util::background_backoff(),
424///     || async {
425///         // Fallible network calls …
426///         Ok(())
427///     },
428/// )
429/// .await
430/// .expect("never fails");
431/// # });
432/// ```
433///
434/// # Returns
435///
436/// - If the closure runs successfully, the result is immediately returned
437/// - If the closure did not run successfully for `max_attempts` times, the
438///   error of the closure is returned
439pub async fn retry<F, Fut, T>(
440    op_name: impl Into<String>,
441    strategy: impl backoff_util::Backoff,
442    op_fn: F,
443) -> Result<T, anyhow::Error>
444where
445    F: Fn() -> Fut,
446    Fut: Future<Output = Result<T, anyhow::Error>>,
447{
448    let mut strategy = strategy;
449    let op_name = op_name.into();
450    let mut attempts: u64 = 0;
451    loop {
452        attempts += 1;
453        match op_fn().await {
454            Ok(result) => return Ok(result),
455            Err(err) => {
456                if let Some(interval) = strategy.next() {
457                    // run closure op_fn again
458                    debug!(
459                        target: LOG_CORE,
460                        err = %err.fmt_compact_anyhow(),
461                        %attempts,
462                        interval = interval.as_secs(),
463                        "{} failed, retrying",
464                        op_name,
465                    );
466                    runtime::sleep(interval).await;
467                } else {
468                    warn!(
469                        target: LOG_CORE,
470                        err = %err.fmt_compact_anyhow(),
471                        %attempts,
472                        "{} failed",
473                        op_name,
474                    );
475                    return Err(err);
476                }
477            }
478        }
479    }
480}
481
#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicU8, Ordering};
    use std::time::Duration;

    use anyhow::anyhow;
    use fedimint_core::runtime::Elapsed;
    use futures::FutureExt;

    use super::*;
    use crate::runtime::timeout;

    /// `(input, expected Display, expected Debug, expected without_auth().as_str())`
    const SAFE_URL_CASES: &[(&str, &str, &str, &str)] = &[
        (
            "http://1.2.3.4:80/foo",
            "http://1.2.3.4/foo",
            "SafeUrl(http://1.2.3.4/foo)",
            "http://1.2.3.4/foo",
        ),
        (
            "http://1.2.3.4:81/foo",
            "http://1.2.3.4:81/foo",
            "SafeUrl(http://1.2.3.4:81/foo)",
            "http://1.2.3.4:81/foo",
        ),
        (
            "fedimint://1.2.3.4:1000/foo",
            "fedimint://1.2.3.4:1000/foo",
            "SafeUrl(fedimint://1.2.3.4:1000/foo)",
            "fedimint://1.2.3.4:1000/foo",
        ),
        (
            "fedimint://foo:bar@domain.com:1000/foo",
            "fedimint://REDACTEDUSER:REDACTEDPASS@domain.com:1000/foo",
            "SafeUrl(fedimint://REDACTEDUSER:REDACTEDPASS@domain.com:1000/foo)",
            "fedimint://domain.com:1000/foo",
        ),
        (
            "fedimint://foo@1.2.3.4:1000/foo",
            "fedimint://REDACTEDUSER@1.2.3.4:1000/foo",
            "SafeUrl(fedimint://REDACTEDUSER@1.2.3.4:1000/foo)",
            "fedimint://1.2.3.4:1000/foo",
        ),
    ];

    #[test]
    fn test_safe_url() {
        for &(input, want_display, want_debug, want_without_auth) in SAFE_URL_CASES {
            let safe = SafeUrl::parse(input).unwrap();

            assert_eq!(
                format!("{safe}"),
                want_display,
                "Display implementation out of spec"
            );
            assert_eq!(
                format!("{safe:?}"),
                want_debug,
                "Debug implementation out of spec"
            );

            let stripped = safe.without_auth().unwrap();
            assert_eq!(
                stripped.as_str(),
                want_without_auth,
                "Without auth implementation out of spec"
            );
        }

        // Exercise the `TryFrom<Url>` impl through `TryInto`
        let raw = url::Url::parse("http://1.2.3.4:80/foo").unwrap();
        let _: SafeUrl = raw.try_into().unwrap();
    }

    #[tokio::test]
    async fn test_next_or_pending() {
        let mut stream = futures::stream::iter(vec![1, 2]);
        for expected in [1, 2] {
            assert_eq!(stream.next_or_pending().now_or_never(), Some(expected));
        }

        // The stream is exhausted now: the future must stay pending until the
        // timeout fires.
        let outcome = timeout(Duration::from_millis(100), stream.next_or_pending()).await;
        assert!(matches!(outcome, Err(Elapsed)));
    }

    #[tokio::test]
    async fn retry_succeed_with_one_attempt() {
        let calls = AtomicU8::new(0);
        let op = || async {
            calls.fetch_add(1, Ordering::SeqCst);
            // Always succeeds.
            Ok(42)
        };

        let _ = retry("Run once", backoff_util::immediate_backoff(Some(2)), op).await;

        // Success on the first attempt means no retry happens.
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn retry_fail_with_three_attempts() {
        let calls = AtomicU8::new(0);
        let op = || async {
            calls.fetch_add(1, Ordering::SeqCst);
            // Always fails.
            Err::<(), anyhow::Error>(anyhow!("42"))
        };

        let _ = retry("Run 3 times", backoff_util::immediate_backoff(Some(2)), op).await;

        // One initial attempt plus the two retries the backoff allows.
        assert_eq!(calls.load(Ordering::SeqCst), 3);
    }
}