use std::cmp::min;
use std::collections::BTreeMap;
use std::iter::FromIterator;
use std::marker::PhantomData;
use fedimint_core::encoding::{Decodable, DecodeError, Encodable};
use serde::{Deserialize, Serialize};
use crate::module::registry::ModuleDecoderRegistry;
use crate::tiered::InvalidAmountTierError;
use crate::{Amount, Tiered};
/// A collection of items grouped by [`Amount`] tier, where every tier maps to
/// a `Vec` of items.
///
/// The tiers live in a [`BTreeMap`], so all iteration methods visit tiers in
/// ascending `Amount` order (as defined by `Amount`'s `Ord`).
#[derive(Debug, Clone, Eq, PartialEq, Hash, Deserialize, Serialize)]
pub struct TieredMulti<T>(BTreeMap<Amount, Vec<T>>);

impl<T> TieredMulti<T> {
    /// Wraps an existing tier map as-is; tiers holding an empty `Vec` are kept.
    pub fn new(map: BTreeMap<Amount, Vec<T>>) -> Self {
        Self(map)
    }

    /// Sum over all tiers of `tier * number_of_items_in_that_tier`.
    pub fn total_amount(&self) -> Amount {
        let mut total_msats: u64 = 0;
        for (tier, items) in &self.0 {
            total_msats += tier.msats * (items.len() as u64);
        }
        Amount::from_msats(total_msats)
    }

    /// Total number of items across every tier.
    pub fn count_items(&self) -> usize {
        self.0.values().map(Vec::len).sum()
    }

    /// Number of tiers in the map, including tiers whose `Vec` is empty.
    pub fn count_tiers(&self) -> usize {
        self.0.len()
    }

    /// Iterates over the tier amounts in ascending order.
    pub fn iter_tiers(&self) -> impl Iterator<Item = &Amount> {
        self.0.keys()
    }

    /// Condenses the collection into a per-tier item count.
    pub fn summary(&self) -> TieredSummary {
        let counts = self.iter().map(|(tier, items)| (*tier, items.len()));
        TieredSummary(Tiered::from_iter(counts))
    }

    /// `true` when no tier holds any item (tiers with an empty `Vec` do not count).
    pub fn is_empty(&self) -> bool {
        self.0.values().all(Vec::is_empty)
    }

    /// Checks that `other` has exactly the same tiers as `self` and that each
    /// tier holds the same number of items; the items themselves are ignored.
    pub fn structural_eq<O>(&self, other: &TieredMulti<O>) -> bool {
        let same_tiers = self.0.keys().eq(other.0.keys());
        // When the tier sets match, comparing the length sequences is the same
        // as comparing the lengths tier by tier.
        let same_sizes = self
            .0
            .values()
            .map(Vec::len)
            .eq(other.0.values().map(Vec::len));
        same_tiers && same_sizes
    }

    /// Iterates over `(tier, items)` pairs in ascending tier order.
    pub fn iter(&self) -> impl Iterator<Item = (&Amount, &Vec<T>)> {
        self.0.iter()
    }

    /// Iterates over every item, each paired with its tier amount, in ascending
    /// tier order (items within a tier keep their `Vec` order).
    pub fn iter_items(&self) -> impl Iterator<Item = (Amount, &T)> + DoubleEndedIterator {
        self.0
            .iter()
            .flat_map(|(&tier, items)| items.iter().map(move |item| (tier, item)))
    }

    /// Consuming counterpart of [`Self::iter_items`].
    pub fn into_iter_items(self) -> impl Iterator<Item = (Amount, T)> + DoubleEndedIterator {
        self.0
            .into_iter()
            .flat_map(|(tier, items)| items.into_iter().map(move |item| (tier, item)))
    }

    /// Length of the largest tier other than `except`, or `0` if there is none.
    pub fn longest_tier_except(&self, except: &Amount) -> usize {
        self.0
            .iter()
            .filter(|&(tier, _)| tier != except)
            .map(|(_, items)| items.len())
            .max()
            .unwrap_or(0)
    }

    /// Verifies that every tier of `self` is present in `keys`.
    ///
    /// # Errors
    /// Returns [`InvalidAmountTierError`] carrying the smallest tier of `self`
    /// that `keys` does not contain.
    pub fn all_tiers_exist_in<K>(&self, keys: &Tiered<K>) -> Result<(), InvalidAmountTierError> {
        for tier in self.0.keys() {
            if keys.get(*tier).is_none() {
                return Err(InvalidAmountTierError(*tier));
            }
        }
        Ok(())
    }

    /// Items stored under tier `amt`, if that tier exists.
    pub fn get(&self, amt: Amount) -> Option<&Vec<T>> {
        self.0.get(&amt)
    }

    /// Mutable access to the items stored under tier `amt`, if that tier exists.
    pub fn get_mut(&mut self, amt: Amount) -> Option<&mut Vec<T>> {
        self.0.get_mut(&amt)
    }
}
impl<C> FromIterator<(Amount, C)> for TieredMulti<C> {
fn from_iter<T: IntoIterator<Item = (Amount, C)>>(iter: T) -> Self {
let mut res = TieredMulti::default();
res.extend(iter);
res
}
}
impl<C> IntoIterator for TieredMulti<C>
where
    C: 'static + Send,
{
    type Item = (Amount, C);
    type IntoIter = Box<dyn Iterator<Item = (Amount, C)> + Send>;

    /// Consumes the collection and yields `(tier, item)` pairs in ascending tier
    /// order. This is the same sequence as `into_iter_items`, boxed so the
    /// iterator type can be named.
    fn into_iter(self) -> Self::IntoIter {
        Box::new(self.into_iter_items())
    }
}
impl<C> Default for TieredMulti<C> {
fn default() -> Self {
TieredMulti(BTreeMap::default())
}
}
impl<C> Extend<(Amount, C)> for TieredMulti<C> {
    /// Appends each item to the `Vec` of its tier, creating the tier if it is
    /// not present yet.
    fn extend<T: IntoIterator<Item = (Amount, C)>>(&mut self, iter: T) {
        iter.into_iter().for_each(|(tier, item)| {
            self.0.entry(tier).or_default().push(item);
        });
    }
}
impl<C> Encodable for TieredMulti<C>
where
    C: Encodable,
{
    /// Encodes exactly like the inner `BTreeMap<Amount, Vec<C>>`.
    fn consensus_encode<W: std::io::Write>(&self, writer: &mut W) -> Result<usize, std::io::Error> {
        let Self(tiers) = self;
        tiers.consensus_encode(writer)
    }
}
impl<C> Decodable for TieredMulti<C>
where
C: Decodable,
{
fn consensus_decode<D: std::io::Read>(
d: &mut D,
modules: &ModuleDecoderRegistry,
) -> Result<Self, DecodeError> {
Ok(TieredMulti(BTreeMap::consensus_decode(d, modules)?))
}
}
/// Walks several `(Amount, item)` iterators in lock-step. Each step takes one
/// element from every iterator and yields the shared amount together with all
/// of the taken items (see the `Iterator` impl).
pub struct TieredMultiZip<'a, I, T>
where
    I: 'a,
{
    iters: Vec<I>,
    // Only ties the otherwise-unused `'a` and `T` parameters to the struct.
    _pd: PhantomData<&'a T>,
}

impl<'a, I, C> TieredMultiZip<'a, I, C> {
    /// Creates a zip over `iters`.
    ///
    /// # Panics
    /// Panics if `iters` is empty.
    pub fn new(iters: Vec<I>) -> Self {
        assert!(!iters.is_empty());
        Self {
            iters,
            _pd: PhantomData,
        }
    }
}
impl<'a, I, C> Iterator for TieredMultiZip<'a, I, C>
where
    I: Iterator<Item = (Amount, C)>,
{
    type Item = (Amount, Vec<C>);

    /// Pulls one element from every inner iterator, in order.
    ///
    /// Returns `None` as soon as any inner iterator is exhausted. Elements
    /// already pulled from earlier iterators during that step are dropped.
    ///
    /// # Panics
    /// Panics if the inner iterators yield different amounts in the same step.
    fn next(&mut self) -> Option<Self::Item> {
        let mut tier: Option<Amount> = None;
        let mut items = Vec::with_capacity(self.iters.len());

        for source in &mut self.iters {
            let (amt, item) = source.next()?;
            match tier {
                // Every iterator must agree with the first one's amount.
                Some(expected) => assert_eq!(expected, amt),
                None => tier = Some(amt),
            }
            items.push(item);
        }

        assert_eq!(items.len(), self.iters.len());
        let tier = tier.expect("The multi zip must contain at least one iterator");
        Some((tier, items))
    }
}
#[derive(Debug, PartialEq, Default, Serialize, Deserialize, Clone)]
pub struct TieredSummary(Tiered<usize>);
impl TieredSummary {
pub fn represent_amount<K>(
amount: Amount,
current_denominations: &TieredSummary,
tiers: &Tiered<K>,
denomination_sets: u16,
) -> TieredSummary {
let mut remaining_amount = amount;
let mut denominations = TieredSummary::default();
for tier in tiers.tiers() {
let notes = current_denominations
.0
.get(*tier)
.copied()
.unwrap_or_default();
let missing_notes = (denomination_sets as u64).saturating_sub(notes as u64);
let possible_notes = remaining_amount / *tier;
let add_notes = min(possible_notes, missing_notes);
denominations.inc(*tier, add_notes as usize);
remaining_amount -= *tier * add_notes;
}
for tier in tiers.tiers().rev() {
let res = remaining_amount / *tier;
remaining_amount %= *tier;
denominations.inc(*tier, res as usize);
}
let represented: u64 = denominations
.0
.iter()
.map(|(k, v)| k.msats * (*v as u64))
.sum();
assert_eq!(represented, amount.msats);
denominations
}
pub fn inc(&mut self, tier: Amount, n: usize) {
*self.0.get_mut_or_default(tier) += n;
}
pub fn iter(&self) -> impl Iterator<Item = (Amount, usize)> + '_ {
self.0.iter().map(|(k, v)| (k, *v))
}
pub fn total_amount(&self) -> Amount {
self.0.iter().map(|(k, v)| k * (*v as u64)).sum::<Amount>()
}
pub fn count_items(&self) -> usize {
self.0.iter().map(|(_, v)| *v).sum()
}
pub fn count_tiers(&self) -> usize {
self.0.count_tiers()
}
}
impl FromIterator<(Amount, usize)> for TieredSummary {
fn from_iter<I: IntoIterator<Item = (Amount, usize)>>(iter: I) -> Self {
TieredSummary(iter.into_iter().collect())
}
}
#[cfg(test)]
mod test {
    use fedimint_core::Amount;

    use super::*;

    #[test]
    fn summary_works() {
        let expected_counts = vec![
            (Amount::from_sats(1), 1),
            (Amount::from_sats(2), 3),
            (Amount::from_sats(3), 2),
        ];
        let multi = notes(expected_counts.clone());
        let summary = multi.summary();

        assert_eq!(summary.iter().collect::<Vec<_>>(), expected_counts);
        assert_eq!(summary.total_amount(), multi.total_amount());
        assert_eq!(summary.count_items(), multi.count_items());
        assert_eq!(summary.count_tiers(), multi.count_tiers());
    }

    #[test]
    fn represent_amount_targets_denomination_sets() {
        let starting = notes(vec![
            (Amount::from_sats(1), 1),
            (Amount::from_sats(2), 3),
            (Amount::from_sats(3), 2),
        ])
        .summary();
        let tier_set = tiers(vec![1, 2, 3, 4]);
        let six_sats = Amount::from_sats(6);

        // 3 sets: tier 1 is topped up by 2 and tier 3 by 1 (5 sats). The last
        // sat goes to tier 1 in the greedy pass.
        assert_eq!(
            TieredSummary::represent_amount(six_sats, &starting, &tier_set, 3),
            denominations(vec![
                (Amount::from_sats(1), 3),
                (Amount::from_sats(2), 0),
                (Amount::from_sats(3), 1),
                (Amount::from_sats(4), 0)
            ])
        );

        // 2 sets: tier 1 gets 1 and tier 4 gets 1 (5 sats). The last sat again
        // lands in tier 1.
        assert_eq!(
            TieredSummary::represent_amount(six_sats, &starting, &tier_set, 2),
            denominations(vec![
                (Amount::from_sats(1), 2),
                (Amount::from_sats(2), 0),
                (Amount::from_sats(3), 0),
                (Amount::from_sats(4), 1)
            ])
        );
    }

    /// Builds a `TieredMulti` holding `count` dummy `0` items for every
    /// `(tier, count)` pair.
    fn notes(notes: Vec<(Amount, usize)>) -> TieredMulti<usize> {
        notes
            .into_iter()
            .flat_map(|(tier, count)| std::iter::repeat((tier, 0usize)).take(count))
            .collect()
    }

    /// Builds a `Tiered<()>` whose tiers are the given satoshi amounts.
    fn tiers(tiers: Vec<u64>) -> Tiered<()> {
        tiers
            .into_iter()
            .map(Amount::from_sats)
            .zip(std::iter::repeat(()))
            .collect()
    }

    /// Builds a `TieredSummary` from explicit `(tier, count)` pairs.
    fn denominations(denominations: Vec<(Amount, usize)>) -> TieredSummary {
        denominations.into_iter().collect()
    }
}