1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
pub(super) mod xmd;
pub(super) mod xof;
use crate::{Error, Result};
use digest::{Digest, ExtendableOutput, Update, XofReader};
use generic_array::typenum::{IsLess, U256};
use generic_array::{ArrayLength, GenericArray};
/// Prefix hashed in front of a domain separation tag that is too long to be
/// used directly (RFC 9380, section 5.3.3).
const OVERSIZE_DST_SALT: &[u8] = b"H2C-OVERSIZE-DST-";

/// Longest domain separation tag, in bytes, that is used without hashing.
///
/// Expressed as `u8::MAX` because the tag length is reported as a single
/// byte (see `Domain::len`).
const MAX_DST_LEN: usize = u8::MAX as usize;
/// Interface for `expand_message` implementations (cf. RFC 9380, section 5.3).
///
/// An implementor turns a list of message slices and a domain separation tag
/// into an [`Expander`], from which the expanded bytes are then read.
pub trait ExpandMsg<'a> {
    /// Reader type handed back by [`ExpandMsg::expand_message`].
    ///
    /// Associated types carry an implicit `Sized` bound, so none is spelled out.
    type Expander: Expander;

    /// Prepares expansion of `msgs` under the domain separation tag `dst`.
    ///
    /// `len_in_bytes` is, going by its name, the total number of output bytes
    /// the caller intends to read (confirm against implementors).
    ///
    /// # Errors
    ///
    /// Implementation-defined; see the individual implementors.
    fn expand_message(
        msgs: &[&[u8]],
        dst: &'a [u8],
        len_in_bytes: usize,
    ) -> Result<Self::Expander>;
}
/// Source of the expanded bytes produced by an [`ExpandMsg`] implementation.
pub trait Expander {
/// Writes expanded output bytes into `okm`.
///
/// Takes `&mut self`, so implementors may keep and advance internal state
/// across successive calls.
fn fill_bytes(&mut self, okm: &mut [u8]);
}
/// Domain separation tag (DST) handed to the message expanders.
///
/// Built by [`Domain::xof`] / [`Domain::xmd`]: a non-empty tag of at most
/// `MAX_DST_LEN` bytes is kept as a borrow, while a longer tag is replaced by
/// its hash. The `IsLess<U256>` bound keeps the hashed length below 256, so
/// it can always be reported as a `u8`.
pub(crate) enum Domain<'a, L>
where
L: ArrayLength<u8> + IsLess<U256>,
{
/// Hash of an oversized tag, exactly `L` bytes long.
Hashed(GenericArray<u8, L>),
/// Caller-supplied tag, used unchanged.
Array(&'a [u8]),
}
impl<'a, L> Domain<'a, L>
where
L: ArrayLength<u8> + IsLess<U256>,
{
pub fn xof<X>(dst: &'a [u8]) -> Result<Self>
where
X: Default + ExtendableOutput + Update,
{
if dst.is_empty() {
Err(Error)
} else if dst.len() > MAX_DST_LEN {
let mut data = GenericArray::<u8, L>::default();
X::default()
.chain(OVERSIZE_DST_SALT)
.chain(dst)
.finalize_xof()
.read(&mut data);
Ok(Self::Hashed(data))
} else {
Ok(Self::Array(dst))
}
}
pub fn xmd<X>(dst: &'a [u8]) -> Result<Self>
where
X: Digest<OutputSize = L>,
{
if dst.is_empty() {
Err(Error)
} else if dst.len() > MAX_DST_LEN {
Ok(Self::Hashed({
let mut hash = X::new();
hash.update(OVERSIZE_DST_SALT);
hash.update(dst);
hash.finalize()
}))
} else {
Ok(Self::Array(dst))
}
}
pub fn data(&self) -> &[u8] {
match self {
Self::Hashed(d) => &d[..],
Self::Array(d) => *d,
}
}
pub fn len(&self) -> u8 {
match self {
Self::Hashed(_) => L::to_u8(),
Self::Array(d) => u8::try_from(d.len()).expect("length overflow"),
}
}
#[cfg(test)]
pub fn assert(&self, bytes: &[u8]) {
assert_eq!(self.data(), &bytes[..bytes.len() - 1]);
assert_eq!(self.len(), bytes[bytes.len() - 1]);
}
}