// `HeaderMap` uses `HeaderName` as a `HashMap` key, and `HeaderName` wraps a
// `bytes::Bytes` (via `CustomHeader`). Clippy's `mutable_key_type` lint
// presumably treats that type as interior-mutable, even though the name bytes
// are never mutated after construction. NOTE(review): confirm the allow is
// still needed with current clippy/bytes versions.
#![allow(clippy::mutable_key_type)]
20
21use std::{collections::HashMap, fmt, slice::Iter, str::FromStr};
24
25use bytes::Bytes;
26use serde::{Deserialize, Serialize};
27
/// A multi-valued map of NATS message headers.
///
/// Each [`HeaderName`] maps to a list of [`HeaderValue`]s kept in the order
/// they were added. [`HeaderMap::insert`] replaces all values stored under a
/// name; [`HeaderMap::append`] adds one more.
#[derive(Clone, PartialEq, Eq, Debug, Default, Deserialize, Serialize)]
pub struct HeaderMap {
    // Name -> values. `insert` stores a one-element vec, `append` pushes onto it.
    inner: HashMap<HeaderName, Vec<HeaderValue>>,
}
51
52impl FromIterator<(HeaderName, HeaderValue)> for HeaderMap {
53 fn from_iter<T: IntoIterator<Item = (HeaderName, HeaderValue)>>(iter: T) -> Self {
54 let mut header_map = HeaderMap::new();
55 for (key, value) in iter {
56 header_map.insert(key, value);
57 }
58 header_map
59 }
60}
61
62impl HeaderMap {
63 pub fn iter(&self) -> std::collections::hash_map::Iter<'_, HeaderName, Vec<HeaderValue>> {
64 self.inner.iter()
65 }
66}
67
/// Iterator over all values stored under a single header name, as returned
/// by [`HeaderMap::get_all`]. Values are yielded in the order they are stored.
pub struct GetAll<'a, T> {
    inner: Iter<'a, T>,
}

impl<'a, T> Iterator for GetAll<'a, T> {
    type Item = &'a T;

    fn next(&mut self) -> Option<Self::Item> {
        self.inner.next()
    }

    // Forward the exact length of the underlying slice iterator so that
    // `collect`, `len` and similar adaptors can rely on it.
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.inner.size_hint()
    }
}

impl<T> DoubleEndedIterator for GetAll<'_, T> {
    fn next_back(&mut self) -> Option<Self::Item> {
        self.inner.next_back()
    }
}

// The wrapped slice iterator reports an exact size and is fused, so these
// guarantees carry over unchanged.
impl<T> ExactSizeIterator for GetAll<'_, T> {}

impl<T> std::iter::FusedIterator for GetAll<'_, T> {}
79
80impl HeaderMap {
81 pub fn new() -> Self {
92 HeaderMap::default()
93 }
94
95 pub fn is_empty(&self) -> bool {
111 self.inner.is_empty()
112 }
113
114 pub fn len(&self) -> usize {
115 self.inner.len()
116 }
117}
118
119impl HeaderMap {
120 pub fn insert<K: IntoHeaderName, V: IntoHeaderValue>(&mut self, name: K, value: V) {
131 self.inner
132 .insert(name.into_header_name(), vec![value.into_header_value()]);
133 }
134
135 pub fn append<K: IntoHeaderName, V: IntoHeaderValue>(&mut self, name: K, value: V) {
147 let key = name.into_header_name();
148 let v = self.inner.get_mut(&key);
149 match v {
150 Some(v) => {
151 v.push(value.into_header_value());
152 }
153 None => {
154 self.insert(key, value.into_header_value());
155 }
156 }
157 }
158
159 pub fn get<K: IntoHeaderName>(&self, key: K) -> Option<&HeaderValue> {
171 self.inner
172 .get(&key.into_header_name())
173 .and_then(|x| x.first())
174 }
175
176 pub fn get_last<K: IntoHeaderName>(&self, key: K) -> Option<&HeaderValue> {
188 self.inner
189 .get(&key.into_header_name())
190 .and_then(|x| x.last())
191 }
192
193 pub fn get_all<K: IntoHeaderName>(&self, key: K) -> GetAll<HeaderValue> {
208 let inner = self
209 .inner
210 .get(&key.into_header_name())
211 .map(|x| x.iter())
212 .unwrap_or([].iter());
213
214 GetAll { inner }
215 }
216
217 pub(crate) fn to_bytes(&self) -> Vec<u8> {
218 let mut buf = vec![];
219 buf.extend_from_slice(b"NATS/1.0\r\n");
220 for (k, vs) in &self.inner {
221 for v in vs.iter() {
222 buf.extend_from_slice(k.as_str().as_bytes());
223 buf.extend_from_slice(b": ");
224 buf.extend_from_slice(v.inner.as_bytes());
225 buf.extend_from_slice(b"\r\n");
226 }
227 }
228 buf.extend_from_slice(b"\r\n");
229 buf
230 }
231}
232
/// A single header value, stored as a `String`.
///
/// Parsing with `str::parse` / `HeaderValue::from_str` rejects values that
/// contain `\r` or `\n`. The infallible `From` and `IntoHeaderValue`
/// conversions do not validate their input.
#[derive(Clone, PartialEq, Eq, Debug, Default, Serialize, Deserialize)]
pub struct HeaderValue {
    inner: String,
}
248
249impl fmt::Display for HeaderValue {
250 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
251 fmt::Display::fmt(&self.as_str(), f)
252 }
253}
254
255impl AsRef<[u8]> for HeaderValue {
256 fn as_ref(&self) -> &[u8] {
257 self.inner.as_ref()
258 }
259}
260
261impl AsRef<str> for HeaderValue {
262 fn as_ref(&self) -> &str {
263 self.as_str()
264 }
265}
266
267impl From<i16> for HeaderValue {
268 fn from(v: i16) -> Self {
269 Self {
270 inner: v.to_string(),
271 }
272 }
273}
274
275impl From<i32> for HeaderValue {
276 fn from(v: i32) -> Self {
277 Self {
278 inner: v.to_string(),
279 }
280 }
281}
282
283impl From<i64> for HeaderValue {
284 fn from(v: i64) -> Self {
285 Self {
286 inner: v.to_string(),
287 }
288 }
289}
290
291impl From<isize> for HeaderValue {
292 fn from(v: isize) -> Self {
293 Self {
294 inner: v.to_string(),
295 }
296 }
297}
298
299impl From<u16> for HeaderValue {
300 fn from(v: u16) -> Self {
301 Self {
302 inner: v.to_string(),
303 }
304 }
305}
306
307impl From<u32> for HeaderValue {
308 fn from(v: u32) -> Self {
309 Self {
310 inner: v.to_string(),
311 }
312 }
313}
314
315impl From<u64> for HeaderValue {
316 fn from(v: u64) -> Self {
317 Self {
318 inner: v.to_string(),
319 }
320 }
321}
322
323impl From<usize> for HeaderValue {
324 fn from(v: usize) -> Self {
325 Self {
326 inner: v.to_string(),
327 }
328 }
329}
330
331impl FromStr for HeaderValue {
332 type Err = ParseHeaderValueError;
333
334 fn from_str(s: &str) -> Result<Self, Self::Err> {
335 if s.contains(['\r', '\n']) {
336 return Err(ParseHeaderValueError);
337 }
338
339 Ok(HeaderValue {
340 inner: s.to_string(),
341 })
342 }
343}
344
345impl From<&str> for HeaderValue {
346 fn from(v: &str) -> Self {
347 Self {
348 inner: v.to_string(),
349 }
350 }
351}
352
353impl From<String> for HeaderValue {
354 fn from(inner: String) -> Self {
355 Self { inner }
356 }
357}
358
359impl HeaderValue {
360 pub fn new() -> Self {
361 HeaderValue::default()
362 }
363
364 pub fn as_str(&self) -> &str {
365 self.inner.as_str()
366 }
367}
368
/// Error returned when parsing a [`HeaderValue`] from a string that contains
/// a carriage return (`\r`) or line feed (`\n`).
#[derive(Debug, Clone)]
pub struct ParseHeaderValueError;

impl fmt::Display for ParseHeaderValueError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(
            r#"invalid character found in header value (value cannot contain '\r' or '\n')"#,
        )
    }
}

impl std::error::Error for ParseHeaderValueError {}
382
383pub trait IntoHeaderName {
384 fn into_header_name(self) -> HeaderName;
385}
386
387impl IntoHeaderName for &str {
388 fn into_header_name(self) -> HeaderName {
389 HeaderName {
390 inner: HeaderRepr::Custom(self.into()),
391 }
392 }
393}
394
395impl IntoHeaderName for String {
396 fn into_header_name(self) -> HeaderName {
397 HeaderName {
398 inner: HeaderRepr::Custom(self.into()),
399 }
400 }
401}
402
403impl IntoHeaderName for HeaderName {
404 fn into_header_name(self) -> HeaderName {
405 self
406 }
407}
408
409pub trait IntoHeaderValue {
410 fn into_header_value(self) -> HeaderValue;
411}
412
413impl IntoHeaderValue for &str {
414 fn into_header_value(self) -> HeaderValue {
415 HeaderValue {
416 inner: self.to_string(),
417 }
418 }
419}
420
421impl IntoHeaderValue for String {
422 fn into_header_value(self) -> HeaderValue {
423 HeaderValue { inner: self }
424 }
425}
426
427impl IntoHeaderValue for HeaderValue {
428 fn into_header_value(self) -> HeaderValue {
429 self
430 }
431}
432
// Declares the set of standard NATS header names. For each
// `(Variant, CONSTANT, b"Wire-Name");` entry it generates:
// - a variant of the private `StandardHeader` enum,
// - a public `HeaderName` constant (doc attributes on the entry are
//   forwarded to it),
// - arms of `StandardHeader::as_str` / `StandardHeader::from_bytes`, which
//   convert between a variant and its exact, case-sensitive wire name,
// - a unit test checking that parsing each wire name yields its constant.
macro_rules! standard_headers {
    (
        $(
            $(#[$docs:meta])*
            ($variant:ident, $constant:ident, $bytes:literal);
        )+
    ) => {
        #[allow(clippy::enum_variant_names)]
        #[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)]
        enum StandardHeader {
            $(
                $variant,
            )+
        }

        $(
            $(#[$docs])*
            pub const $constant: HeaderName = HeaderName {
                inner: HeaderRepr::Standard(StandardHeader::$variant),
            };
        )+

        impl StandardHeader {
            // Wire name of the header, e.g. "Nats-Stream".
            #[inline]
            fn as_str(&self) -> &'static str {
                match *self {
                    $(
                        // SAFETY: relies on every `$bytes` literal being valid
                        // UTF-8. All entries in this file's invocation are
                        // ASCII byte strings.
                        StandardHeader::$variant => unsafe { std::str::from_utf8_unchecked( $bytes ) },
                    )+
                }
            }

            // Exact (case-sensitive) lookup of a wire name. It is a `const fn`
            // so that the `const fn` `HeaderName::from_static` can call it.
            const fn from_bytes(bytes: &[u8]) -> Option<StandardHeader> {
                match bytes {
                    $(
                        $bytes => Some(StandardHeader::$variant),
                    )+
                    _ => None,
                }
            }
        }

        #[cfg(test)]
        mod standard_header_tests {
            use super::HeaderName;
            use std::str::{self, FromStr};

            // Every generated constant paired with its wire name.
            const TEST_HEADERS: &'static [(&'static HeaderName, &'static [u8])] = &[
                $(
                    (&super::$constant, $bytes),
                )+
            ];

            // Parsing each wire name must produce the predefined constant.
            #[test]
            fn from_str() {
                for &(header, bytes) in TEST_HEADERS {
                    let utf8 = str::from_utf8(bytes).expect("string constants isn't utf8");
                    assert_eq!(HeaderName::from_str(utf8).unwrap(), *header);
                }
            }
        }
    }
}
496
// The standard header names. Wire names must be valid UTF-8 because
// `StandardHeader::as_str` converts them with `from_utf8_unchecked`.
// `StandardHeader::from_bytes` matches them case-sensitively.
standard_headers! {
    /// Standard header name `Nats-Stream`.
    (NatsStream, NATS_STREAM, b"Nats-Stream");
    /// Standard header name `Nats-Sequence`.
    (NatsSequence, NATS_SEQUENCE, b"Nats-Sequence");
    /// Standard header name `Nats-Time-Stamp`.
    (NatsTimeStamp, NATS_TIME_STAMP, b"Nats-Time-Stamp");
    /// Standard header name `Nats-Subject`.
    (NatsSubject, NATS_SUBJECT, b"Nats-Subject");
    /// Standard header name `Nats-Msg-Id`.
    (NatsMessageId, NATS_MESSAGE_ID, b"Nats-Msg-Id");
    /// Standard header name `Nats-Last-Stream`.
    (NatsLastStream, NATS_LAST_STREAM, b"Nats-Last-Stream");
    /// Standard header name `Nats-Last-Consumer`.
    (NatsLastConsumer, NATS_LAST_CONSUMER, b"Nats-Last-Consumer");
    /// Standard header name `Nats-Last-Sequence`.
    (NatsLastSequence, NATS_LAST_SEQUENCE, b"Nats-Last-Sequence");
    /// Standard header name `Nats-Expected-Last-Subject-Sequence`.
    (NatsExpectedLastSubjectSequence, NATS_EXPECTED_LAST_SUBJECT_SEQUENCE, b"Nats-Expected-Last-Subject-Sequence");
    /// Standard header name `Nats-Expected-Last-Msg-Id`.
    (NatsExpectedLastMessageId, NATS_EXPECTED_LAST_MESSAGE_ID, b"Nats-Expected-Last-Msg-Id");
    /// Standard header name `Nats-Expected-Last-Sequence`.
    (NatsExpectedLastSequence, NATS_EXPECTED_LAST_SEQUENCE, b"Nats-Expected-Last-Sequence");
    /// Standard header name `Nats-Expected-Stream`.
    (NatsExpectedStream, NATS_EXPECTED_STREAM, b"Nats-Expected-Stream");
}
524
/// Byte storage for a header name that is not held as a `StandardHeader`
/// variant.
///
/// Invariant: `bytes` always holds valid UTF-8. Every constructor builds it
/// from a `&str` or `String`, and `CustomHeader::as_str` relies on this.
#[derive(Debug, Hash, PartialEq, Eq, Clone)]
struct CustomHeader {
    bytes: Bytes,
}
529
530impl CustomHeader {
531 #[inline]
532 pub(crate) const fn from_static(value: &'static str) -> CustomHeader {
533 CustomHeader {
534 bytes: Bytes::from_static(value.as_bytes()),
535 }
536 }
537
538 #[inline]
539 pub(crate) fn as_str(&self) -> &str {
540 unsafe { std::str::from_utf8_unchecked(self.bytes.as_ref()) }
541 }
542}
543
544impl From<String> for CustomHeader {
545 #[inline]
546 fn from(value: String) -> CustomHeader {
547 CustomHeader {
548 bytes: Bytes::from(value),
549 }
550 }
551}
552
553impl<'a> From<&'a str> for CustomHeader {
554 #[inline]
555 fn from(value: &'a str) -> CustomHeader {
556 CustomHeader {
557 bytes: Bytes::copy_from_slice(value.as_bytes()),
558 }
559 }
560}
561
/// Internal representation of a [`HeaderName`]: either one of the predefined
/// standard headers or arbitrary name bytes.
///
/// Equality and hashing are derived, so `Standard(x)` and a `Custom` holding
/// the same text are *different* values. Constructors should therefore map
/// standard names to `Standard`.
#[derive(Debug, Hash, PartialEq, Eq, Clone)]
enum HeaderRepr {
    Standard(StandardHeader),
    Custom(CustomHeader),
}
567
/// The name of a message header.
///
/// Build one with [`HeaderName::from_static`] or `str::parse` ([`FromStr`]),
/// or use a predefined constant such as [`NATS_STREAM`]. Names are compared
/// case-sensitively.
#[derive(Clone, PartialEq, Eq, Hash, Debug)]
pub struct HeaderName {
    inner: HeaderRepr,
}
581
582impl HeaderName {
583 #[inline]
585 pub const fn from_static(value: &'static str) -> HeaderName {
586 if let Some(standard) = StandardHeader::from_bytes(value.as_bytes()) {
587 return HeaderName {
588 inner: HeaderRepr::Standard(standard),
589 };
590 }
591
592 HeaderName {
593 inner: HeaderRepr::Custom(CustomHeader::from_static(value)),
594 }
595 }
596
597 #[inline]
599 fn as_str(&self) -> &str {
600 match self.inner {
601 HeaderRepr::Standard(v) => v.as_str(),
602 HeaderRepr::Custom(ref v) => v.as_str(),
603 }
604 }
605}
606
607impl FromStr for HeaderName {
608 type Err = ParseHeaderNameError;
609
610 fn from_str(s: &str) -> Result<Self, Self::Err> {
611 if s.contains(|c: char| c == ':' || (c as u8) < 33 || (c as u8) > 126) {
612 return Err(ParseHeaderNameError);
613 }
614
615 match StandardHeader::from_bytes(s.as_ref()) {
616 Some(v) => Ok(HeaderName {
617 inner: HeaderRepr::Standard(v),
618 }),
619 None => Ok(HeaderName {
620 inner: HeaderRepr::Custom(CustomHeader::from(s)),
621 }),
622 }
623 }
624}
625
626impl fmt::Display for HeaderName {
627 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
628 fmt::Display::fmt(&self.as_str(), f)
629 }
630}
631
632impl AsRef<[u8]> for HeaderName {
633 fn as_ref(&self) -> &[u8] {
634 self.as_str().as_bytes()
635 }
636}
637
638impl AsRef<str> for HeaderName {
639 fn as_ref(&self) -> &str {
640 self.as_str()
641 }
642}
643
644impl Serialize for HeaderName {
645 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
646 where
647 S: serde::Serializer,
648 {
649 serializer.serialize_str(self.as_str())
650 }
651}
652
653impl<'de> Deserialize<'de> for HeaderName {
654 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
655 where
656 D: serde::Deserializer<'de>,
657 {
658 String::deserialize(deserializer)?
659 .parse()
660 .map_err(serde::de::Error::custom)
661 }
662}
663
/// Error returned when a string is not a valid [`HeaderName`].
#[derive(Debug, Clone)]
pub struct ParseHeaderNameError;

impl std::fmt::Display for ParseHeaderNameError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Describe the rule `HeaderName::from_str` actually enforces: any
        // printable ASCII character except ':' is accepted, including '$'
        // and '_', not only alphanumerics and '-'.
        f.write_str(
            "invalid header name (name may only contain printable ASCII characters other than ':')",
        )
    }
}

impl std::error::Error for ParseHeaderNameError {}
674
#[cfg(test)]
mod tests {
    use super::{HeaderMap, HeaderName, HeaderValue, IntoHeaderName, IntoHeaderValue};
    use std::str::{from_utf8, FromStr};

    #[test]
    fn try_from() {
        let mut headers = HeaderMap::new();
        headers.insert("name", "something".parse::<HeaderValue>().unwrap());
        headers.insert("name", "something2");

        // A second `insert` replaces the first value instead of adding to it.
        assert_eq!(headers.get("name").unwrap().as_str(), "something2");
        assert_eq!(headers.get_all("name").count(), 1);
    }

    #[test]
    fn append() {
        let mut headers = HeaderMap::new();
        headers.append("Key", "value");
        headers.append("Key", "second_value");

        let mut result = headers.get_all("Key");

        assert_eq!(
            result.next().unwrap(),
            &HeaderValue::from_str("value").unwrap()
        );

        assert_eq!(
            result.next().unwrap(),
            &HeaderValue::from_str("second_value").unwrap()
        );

        assert_eq!(result.next(), None);
    }

    #[test]
    fn get_string() {
        let mut headers = HeaderMap::new();
        headers.append("Key", "value");
        headers.append("Key", "other");

        assert_eq!(headers.get("Key").unwrap().to_string(), "value");

        let key: String = headers.get("Key").unwrap().as_str().into();
        assert_eq!(key, "value".to_string());

        let key: String = headers.get("Key").unwrap().as_str().to_owned();
        assert_eq!(key, "value".to_string());

        assert_eq!(headers.get("Key").unwrap().as_str(), "value");

        let key: String = headers.get_last("Key").unwrap().as_str().into();
        assert_eq!(key, "other".to_string());
    }

    #[test]
    fn insert() {
        let mut headers = HeaderMap::new();
        headers.insert("Key", "Value");

        let mut result = headers.get_all("Key");

        assert_eq!(
            result.next().unwrap(),
            &HeaderValue::from_str("Value").unwrap()
        );
        assert_eq!(result.next(), None);
    }

    #[test]
    fn serialize() {
        let mut headers = HeaderMap::new();
        headers.append("Key", "value");
        headers.append("Key", "second_value");
        headers.insert("Second", "SecondValue");

        let bytes = headers.to_bytes();
        let text = from_utf8(&bytes).expect("serialized headers are valid UTF-8");

        // Names are emitted in `HashMap` order, which is unspecified, so check
        // the framing and that every value line is present.
        assert!(text.starts_with("NATS/1.0\r\n"));
        assert!(text.ends_with("\r\n\r\n"));
        assert!(text.contains("Key: value\r\n"));
        assert!(text.contains("Key: second_value\r\n"));
        assert!(text.contains("Second: SecondValue\r\n"));
        // Version line + three header lines + terminating empty line.
        assert_eq!(text.matches("\r\n").count(), 5);
        // Values of one name keep their insertion order.
        assert!(text.find("Key: value\r\n") < text.find("Key: second_value\r\n"));
    }

    #[test]
    fn is_empty() {
        let mut headers = HeaderMap::new();
        assert!(headers.is_empty());

        headers.append("Key", "value");
        headers.append("Key", "second_value");
        headers.insert("Second", "SecondValue");
        assert!(!headers.is_empty());
    }

    #[test]
    fn parse_value() {
        assert!("Foo\r".parse::<HeaderValue>().is_err());
        assert!("Foo\n".parse::<HeaderValue>().is_err());
        assert!("Foo\r\n".parse::<HeaderValue>().is_err());
    }

    #[test]
    fn valid_header_name() {
        let valid_header_name = "X-Custom-Header";
        let parsed_header = HeaderName::from_str(valid_header_name);

        assert!(
            parsed_header.is_ok(),
            "Expected Ok(HeaderName), but got an error: {:?}",
            parsed_header.err()
        );
    }

    #[test]
    fn dollar_header_name() {
        let valid_header_name = "$X_Custom_Header";
        let parsed_header = HeaderName::from_str(valid_header_name);

        assert!(
            parsed_header.is_ok(),
            "Expected Ok(HeaderName), but got an error: {:?}",
            parsed_header.err()
        );
    }

    #[test]
    fn invalid_header_name_with_space() {
        let invalid_header_name = "X Custom Header";
        let parsed_header = HeaderName::from_str(invalid_header_name);

        assert!(
            parsed_header.is_err(),
            "Expected Err(InvalidHeaderNameError), but got Ok: {:?}",
            parsed_header.ok()
        );
    }

    #[test]
    fn invalid_header_name_with_special_chars() {
        let invalid_header_name = "X-Header:";
        let parsed_header = HeaderName::from_str(invalid_header_name);

        assert!(
            parsed_header.is_err(),
            "Expected Err(InvalidHeaderNameError), but got Ok: {:?}",
            parsed_header.ok()
        );
    }

    #[test]
    fn from_static_eq() {
        let a = HeaderName::from_static("NATS-Stream");
        let b = HeaderName::from_static("NATS-Stream");

        assert_eq!(a, b);
    }

    #[test]
    fn header_name_serde() {
        let raw = "Nats-Stream";
        let raw_json = "\"Nats-Stream\"";
        let header = HeaderName::from_static(raw);

        assert_eq!(serde_json::to_string(&header).unwrap(), raw_json);
        assert_eq!(
            serde_json::from_str::<HeaderName>(raw_json).unwrap(),
            header
        );
    }

    #[test]
    fn header_name_from_string() {
        let string = "NATS-Stream".to_string();
        let name = string.into_header_name();

        assert_eq!("NATS-Stream", name.as_str());
    }

    #[test]
    fn header_value_from_string_with_trait() {
        let string = "some value".to_string();

        let value = string.into_header_value();

        assert_eq!("some value", value.as_str());
    }

    #[test]
    fn header_value_from_string() {
        let string = "some value".to_string();

        let value: HeaderValue = string.into();

        assert_eq!("some value", value.as_str());
    }
}