1use super::error::*;
24use asn1_der::typed::{DerDecodable, DerEncodable, DerTypeView, Sequence};
25use asn1_der::{Asn1DerError, Asn1DerErrorVariant, DerObject, Sink, VecBacking};
26use ring::rand::SystemRandom;
27use ring::signature::KeyPair;
28use ring::signature::{self, RsaKeyPair, RSA_PKCS1_2048_8192_SHA256, RSA_PKCS1_SHA256};
29use std::{fmt, sync::Arc};
30use zeroize::Zeroize;
31
/// An RSA keypair backed by `ring`.
///
/// The key is held behind an [`Arc`], so cloning a `Keypair` only bumps a
/// reference count and shares the same underlying `RsaKeyPair`; the private
/// key material is not duplicated.
#[derive(Clone)]
pub struct Keypair(Arc<RsaKeyPair>);
35
36impl std::fmt::Debug for Keypair {
37 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
38 f.debug_struct("Keypair")
39 .field("public", self.0.public_key())
40 .finish()
41 }
42}
43
44impl Keypair {
45 pub fn try_decode_pkcs1(der: &mut [u8]) -> Result<Keypair, DecodingError> {
50 let kp = RsaKeyPair::from_der(der)
51 .map_err(|e| DecodingError::failed_to_parse("RSA DER PKCS#1 RSAPrivateKey", e))?;
52 der.zeroize();
53 Ok(Keypair(Arc::new(kp)))
54 }
55
56 pub fn try_decode_pkcs8(der: &mut [u8]) -> Result<Keypair, DecodingError> {
61 let kp = RsaKeyPair::from_pkcs8(der)
62 .map_err(|e| DecodingError::failed_to_parse("RSA PKCS#8 PrivateKeyInfo", e))?;
63 der.zeroize();
64 Ok(Keypair(Arc::new(kp)))
65 }
66
67 pub fn public(&self) -> PublicKey {
69 PublicKey(self.0.public_key().as_ref().to_vec())
70 }
71
72 pub fn sign(&self, data: &[u8]) -> Result<Vec<u8>, SigningError> {
74 let mut signature = vec![0; self.0.public().modulus_len()];
75 let rng = SystemRandom::new();
76 match self.0.sign(&RSA_PKCS1_SHA256, &rng, data, &mut signature) {
77 Ok(()) => Ok(signature),
78 Err(e) => Err(SigningError::new("RSA").source(e)),
79 }
80 }
81}
82
/// An RSA public key, stored as its DER-encoded PKCS#1 bytes (the same bytes
/// `encode_pkcs1` returns).
///
/// Equality, ordering and hashing are all derived from the raw encoded bytes.
#[derive(Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct PublicKey(Vec<u8>);
86
87impl PublicKey {
88 pub fn verify(&self, msg: &[u8], sig: &[u8]) -> bool {
90 let key = signature::UnparsedPublicKey::new(&RSA_PKCS1_2048_8192_SHA256, &self.0);
91 key.verify(msg, sig).is_ok()
92 }
93
94 pub fn encode_pkcs1(&self) -> Vec<u8> {
99 self.0.clone()
101 }
102
103 pub fn encode_x509(&self) -> Vec<u8> {
108 let spki = Asn1SubjectPublicKeyInfo {
109 algorithmIdentifier: Asn1RsaEncryption {
110 algorithm: Asn1OidRsaEncryption,
111 parameters: (),
112 },
113 subjectPublicKey: Asn1SubjectPublicKey(self.clone()),
114 };
115 let mut buf = Vec::new();
116 spki.encode(&mut buf)
117 .map(|_| buf)
118 .expect("RSA X.509 public key encoding failed.")
119 }
120
121 pub fn try_decode_x509(pk: &[u8]) -> Result<PublicKey, DecodingError> {
124 Asn1SubjectPublicKeyInfo::decode(pk)
125 .map_err(|e| DecodingError::failed_to_parse("RSA X.509", e))
126 .map(|spki| spki.subjectPublicKey.0)
127 }
128}
129
130impl fmt::Debug for PublicKey {
131 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
132 f.write_str("PublicKey(PKCS1): ")?;
133 for byte in &self.0 {
134 write!(f, "{byte:x}")?;
135 }
136 Ok(())
137 }
138}
139
/// A DER `OBJECT IDENTIFIER` kept in raw, undecoded form.
///
/// Only the content octets are exposed; the individual arcs are never
/// parsed, which is enough to compare the value against a known OID.
#[derive(Copy, Clone)]
struct Asn1RawOid<'a> {
    // The complete DER object (tag, length and value) as parsed by asn1_der.
    object: DerObject<'a>,
}
151
152impl Asn1RawOid<'_> {
153 pub(crate) fn oid(&self) -> &[u8] {
155 self.object.value()
156 }
157
158 pub(crate) fn write<S: Sink>(value: &[u8], sink: &mut S) -> Result<(), Asn1DerError> {
160 DerObject::write(Self::TAG, value.len(), &mut value.iter(), sink)
161 }
162}
163
164impl<'a> DerTypeView<'a> for Asn1RawOid<'a> {
165 const TAG: u8 = 6;
166
167 fn object(&self) -> DerObject<'a> {
168 self.object
169 }
170}
171
172impl DerEncodable for Asn1RawOid<'_> {
173 fn encode<S: Sink>(&self, sink: &mut S) -> Result<(), Asn1DerError> {
174 self.object.encode(sink)
175 }
176}
177
178impl<'a> DerDecodable<'a> for Asn1RawOid<'a> {
179 fn load(object: DerObject<'a>) -> Result<Self, Asn1DerError> {
180 if object.tag() != Self::TAG {
181 return Err(Asn1DerError::new(Asn1DerErrorVariant::InvalidData(
182 "DER object tag is not the object identifier tag.",
183 )));
184 }
185
186 Ok(Self { object })
187 }
188}
189
/// Marker type for the `rsaEncryption` algorithm OID (1.2.840.113549.1.1.1).
///
/// It has no data: encoding always writes that fixed OID, and decoding
/// succeeds only if the input is exactly that OID.
#[derive(Clone)]
struct Asn1OidRsaEncryption;
193
impl Asn1OidRsaEncryption {
    /// DER content octets (no tag or length) of the object identifier
    /// 1.2.840.113549.1.1.1, i.e. `rsaEncryption`:
    /// `0x2A` encodes the first two arcs 1.2 (40 * 1 + 2), `0x86 0x48` is 840,
    /// `0x86 0xF7 0x0D` is 113549, and the final three octets are the arcs
    /// 1, 1, 1.
    const OID: [u8; 9] = [0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x01, 0x01];
}
203
204impl DerEncodable for Asn1OidRsaEncryption {
205 fn encode<S: Sink>(&self, sink: &mut S) -> Result<(), Asn1DerError> {
206 Asn1RawOid::write(&Self::OID, sink)
207 }
208}
209
210impl DerDecodable<'_> for Asn1OidRsaEncryption {
211 fn load(object: DerObject<'_>) -> Result<Self, Asn1DerError> {
212 match Asn1RawOid::load(object)?.oid() {
213 oid if oid == Self::OID => Ok(Self),
214 _ => Err(Asn1DerError::new(Asn1DerErrorVariant::InvalidData(
215 "DER object is not the 'rsaEncryption' identifier.",
216 ))),
217 }
218 }
219}
220
/// The algorithm identifier of an RSA public key: a SEQUENCE holding the
/// `rsaEncryption` OID followed by the parameters element.
struct Asn1RsaEncryption {
    // Always the `rsaEncryption` OID; decoding rejects anything else.
    algorithm: Asn1OidRsaEncryption,
    // Unit parameters. NOTE(review): presumably asn1_der maps `()` to an
    // ASN.1 NULL element; confirm against the asn1_der docs.
    parameters: (),
}
226
227impl DerEncodable for Asn1RsaEncryption {
228 fn encode<S: Sink>(&self, sink: &mut S) -> Result<(), Asn1DerError> {
229 let mut algorithm_buf = Vec::new();
230 let algorithm = self.algorithm.der_object(VecBacking(&mut algorithm_buf))?;
231
232 let mut parameters_buf = Vec::new();
233 let parameters = self
234 .parameters
235 .der_object(VecBacking(&mut parameters_buf))?;
236
237 Sequence::write(&[algorithm, parameters], sink)
238 }
239}
240
241impl DerDecodable<'_> for Asn1RsaEncryption {
242 fn load(object: DerObject<'_>) -> Result<Self, Asn1DerError> {
243 let seq: Sequence = Sequence::load(object)?;
244
245 Ok(Self {
246 algorithm: seq.get_as(0)?,
247 parameters: seq.get_as(1)?,
248 })
249 }
250}
251
/// The `subjectPublicKey` BIT STRING of an X.509 `SubjectPublicKeyInfo`,
/// whose payload is the key's PKCS#1 DER bytes.
struct Asn1SubjectPublicKey(PublicKey);
255
256impl DerEncodable for Asn1SubjectPublicKey {
257 fn encode<S: Sink>(&self, sink: &mut S) -> Result<(), Asn1DerError> {
258 let pk_der = &(self.0).0;
259 let mut bit_string = Vec::with_capacity(pk_der.len() + 1);
260 bit_string.push(0u8);
263 bit_string.extend(pk_der);
264 DerObject::write(3, bit_string.len(), &mut bit_string.iter(), sink)?;
265 Ok(())
266 }
267}
268
269impl DerDecodable<'_> for Asn1SubjectPublicKey {
270 fn load(object: DerObject<'_>) -> Result<Self, Asn1DerError> {
271 if object.tag() != 3 {
272 return Err(Asn1DerError::new(Asn1DerErrorVariant::InvalidData(
273 "DER object tag is not the bit string tag.",
274 )));
275 }
276
277 let pk_der: Vec<u8> = object.value().iter().skip(1).cloned().collect();
278 Ok(Self(PublicKey(pk_der)))
281 }
282}
283
/// An X.509 `SubjectPublicKeyInfo` for an RSA key: a SEQUENCE of the
/// algorithm identifier and the subject public key BIT STRING.
// The field names deliberately mirror the camelCase member names of the
// ASN.1 `SubjectPublicKeyInfo` definition, hence the lint allowance.
#[allow(non_snake_case)]
struct Asn1SubjectPublicKeyInfo {
    algorithmIdentifier: Asn1RsaEncryption,
    subjectPublicKey: Asn1SubjectPublicKey,
}
290
291impl DerEncodable for Asn1SubjectPublicKeyInfo {
292 fn encode<S: Sink>(&self, sink: &mut S) -> Result<(), Asn1DerError> {
293 let mut identifier_buf = Vec::new();
294 let identifier = self
295 .algorithmIdentifier
296 .der_object(VecBacking(&mut identifier_buf))?;
297
298 let mut key_buf = Vec::new();
299 let key = self.subjectPublicKey.der_object(VecBacking(&mut key_buf))?;
300
301 Sequence::write(&[identifier, key], sink)
302 }
303}
304
305impl DerDecodable<'_> for Asn1SubjectPublicKeyInfo {
306 fn load(object: DerObject<'_>) -> Result<Self, Asn1DerError> {
307 let seq: Sequence = Sequence::load(object)?;
308
309 Ok(Self {
310 algorithmIdentifier: seq.get_as(0)?,
311 subjectPublicKey: seq.get_as(1)?,
312 })
313 }
314}
315
#[cfg(test)]
mod tests {
    use super::*;
    use quickcheck::*;

    const KEY1: &[u8] = include_bytes!("test/rsa-2048.pk8");
    const KEY2: &[u8] = include_bytes!("test/rsa-3072.pk8");
    const KEY3: &[u8] = include_bytes!("test/rsa-4096.pk8");

    /// All fixture keys, so every test draws from the same set.
    const ALL_KEYS: [&[u8]; 3] = [KEY1, KEY2, KEY3];

    /// A keypair decoded from one of the PKCS#8 fixtures, picked at random.
    #[derive(Clone, Debug)]
    struct SomeKeypair(Keypair);

    impl Arbitrary for SomeKeypair {
        fn arbitrary(g: &mut Gen) -> SomeKeypair {
            let fixture = *g.choose(&ALL_KEYS).unwrap();
            // Decoding zeroizes its input, so hand it a scratch copy.
            let mut der = fixture.to_vec();
            SomeKeypair(Keypair::try_decode_pkcs8(&mut der).unwrap())
        }
    }

    #[test]
    fn rsa_from_pkcs8() {
        for key in ALL_KEYS.iter() {
            assert!(Keypair::try_decode_pkcs8(&mut key.to_vec()).is_ok());
        }
    }

    #[test]
    fn rsa_x509_encode_decode() {
        fn prop(SomeKeypair(kp): SomeKeypair) -> Result<bool, String> {
            let original = kp.public();
            let encoded = original.encode_x509();
            let decoded = PublicKey::try_decode_x509(&encoded).map_err(|e| e.to_string())?;
            Ok(decoded == original)
        }
        QuickCheck::new().tests(10).quickcheck(prop as fn(_) -> _);
    }

    #[test]
    fn rsa_sign_verify() {
        fn prop(SomeKeypair(kp): SomeKeypair, msg: Vec<u8>) -> Result<bool, SigningError> {
            let sig = kp.sign(&msg)?;
            Ok(kp.public().verify(&msg, &sig))
        }
        QuickCheck::new()
            .tests(10)
            .quickcheck(prop as fn(_, _) -> _);
    }
}