1use std::collections::{BTreeMap, BTreeSet};
2use std::fmt::Debug;
3use std::mem;
4
5use anyhow::anyhow;
6use fedimint_core::task::{MaybeSend, MaybeSync};
7use fedimint_core::{maybe_add_send_sync, NumPeers, PeerId};
8
9use crate::api::{self, PeerError, PeerResult};
10
11pub trait QueryStrategy<IR, OR = IR> {
18 fn process(&mut self, peer_id: PeerId, response: api::PeerResult<IR>) -> QueryStep<OR>;
19}
20
21#[derive(Debug)]
27pub enum QueryStep<R> {
28 Retry(BTreeSet<PeerId>),
30 Continue,
32 Success(R),
34 Failure {
36 general: Option<anyhow::Error>,
37 peers: BTreeMap<PeerId, PeerError>,
38 },
39}
40
41struct ErrorStrategy {
42 errors: BTreeMap<PeerId, PeerError>,
43 threshold: usize,
44}
45
46impl ErrorStrategy {
47 pub fn new(threshold: usize) -> Self {
48 assert!(threshold > 0);
49
50 Self {
51 errors: BTreeMap::new(),
52 threshold,
53 }
54 }
55
56 fn format_errors(&self) -> String {
57 use std::fmt::Write;
58 self.errors
59 .iter()
60 .fold(String::new(), |mut s, (peer_id, e)| {
61 if !s.is_empty() {
62 write!(s, ", ").expect("can't fail");
63 }
64 write!(s, "peer-{peer_id}: {e}").expect("can't fail");
65
66 s
67 })
68 }
69
70 pub fn process<R>(&mut self, peer: PeerId, error: PeerError) -> QueryStep<R> {
71 assert!(self.errors.insert(peer, error).is_none());
72
73 if self.errors.len() == self.threshold {
74 QueryStep::Failure {
75 general: Some(anyhow!(
76 "Received errors from {} peers: {}",
77 self.threshold,
78 self.format_errors()
79 )),
80 peers: mem::take(&mut self.errors),
81 }
82 } else {
83 QueryStep::Continue
84 }
85 }
86}
87
88pub struct FilterMap<R, T> {
91 filter_map: Box<maybe_add_send_sync!(dyn Fn(R) -> anyhow::Result<T>)>,
92 error_strategy: ErrorStrategy,
93}
94
95impl<R, T> FilterMap<R, T> {
96 pub fn new(
97 filter_map: impl Fn(R) -> anyhow::Result<T> + MaybeSend + MaybeSync + 'static,
98 num_peers: NumPeers,
99 ) -> Self {
100 Self {
101 filter_map: Box::new(filter_map),
102 error_strategy: ErrorStrategy::new(num_peers.threshold()),
103 }
104 }
105}
106
107impl<R, T> QueryStrategy<R, T> for FilterMap<R, T> {
108 fn process(&mut self, peer: PeerId, result: PeerResult<R>) -> QueryStep<T> {
109 match result {
110 Ok(response) => match (self.filter_map)(response) {
111 Ok(value) => QueryStep::Success(value),
112 Err(error) => self
113 .error_strategy
114 .process(peer, PeerError::InvalidResponse(error.to_string())),
115 },
116 Err(error) => self.error_strategy.process(peer, error),
117 }
118 }
119}
120
121pub struct FilterMapThreshold<R, T> {
124 filter_map: Box<maybe_add_send_sync!(dyn Fn(PeerId, R) -> anyhow::Result<T>)>,
125 error_strategy: ErrorStrategy,
126 filtered_responses: BTreeMap<PeerId, T>,
127 threshold: usize,
128}
129
130impl<R, T> FilterMapThreshold<R, T> {
131 pub fn new(
132 verifier: impl Fn(PeerId, R) -> anyhow::Result<T> + MaybeSend + MaybeSync + 'static,
133 num_peers: NumPeers,
134 ) -> Self {
135 Self {
136 filter_map: Box::new(verifier),
137 error_strategy: ErrorStrategy::new(num_peers.one_honest()),
138 filtered_responses: BTreeMap::new(),
139 threshold: num_peers.threshold(),
140 }
141 }
142}
143
144impl<R, T> QueryStrategy<R, BTreeMap<PeerId, T>> for FilterMapThreshold<R, T> {
145 fn process(&mut self, peer: PeerId, result: PeerResult<R>) -> QueryStep<BTreeMap<PeerId, T>> {
146 match result {
147 Ok(response) => match (self.filter_map)(peer, response) {
148 Ok(response) => {
149 self.filtered_responses.insert(peer, response);
150
151 if self.filtered_responses.len() == self.threshold {
152 QueryStep::Success(mem::take(&mut self.filtered_responses))
153 } else {
154 QueryStep::Continue
155 }
156 }
157 Err(error) => self
158 .error_strategy
159 .process(peer, PeerError::InvalidResponse(error.to_string())),
160 },
161 Err(error) => self.error_strategy.process(peer, error),
162 }
163 }
164}
165
166pub struct ThresholdConsensus<R> {
171 error_strategy: ErrorStrategy,
172 responses: BTreeMap<PeerId, R>,
173 retry: BTreeSet<PeerId>,
174 threshold: usize,
175}
176
177impl<R> ThresholdConsensus<R> {
178 pub fn new(num_peers: NumPeers) -> Self {
179 Self {
180 error_strategy: ErrorStrategy::new(num_peers.one_honest()),
181 responses: BTreeMap::new(),
182 retry: BTreeSet::new(),
183 threshold: num_peers.threshold(),
184 }
185 }
186}
187
188impl<R: Eq> QueryStrategy<R> for ThresholdConsensus<R> {
189 fn process(&mut self, peer: PeerId, result: PeerResult<R>) -> QueryStep<R> {
190 match result {
191 Ok(response) => {
192 let current_count = self.responses.values().filter(|r| **r == response).count();
193
194 if current_count + 1 >= self.threshold {
195 return QueryStep::Success(response);
196 }
197
198 self.responses.insert(peer, response);
199
200 assert!(self.retry.insert(peer));
201
202 if self.retry.len() == self.threshold {
203 QueryStep::Retry(mem::take(&mut self.retry))
204 } else {
205 QueryStep::Continue
206 }
207 }
208 Err(error) => self.error_strategy.process(peer, error),
209 }
210 }
211}