surrealdb_core/sql/
strand.rs

1use crate::err::Error;
2use crate::sql::escape::quote_plain_str;
3use revision::revisioned;
4use serde::{Deserialize, Serialize};
5use std::fmt::{self, Display, Formatter};
6use std::ops::Deref;
7use std::ops::{self};
8use std::str;
9
10use super::value::TryAdd;
11
/// Serde type name used for [`Strand`].
///
/// This is the same literal as the `#[serde(rename = ...)]` on `Strand`. It is repeated there
/// because that attribute takes a string literal, so the two copies must be edited together.
/// NOTE(review): presumably read by the crate's custom serde (de)serializer to recognise
/// `Strand` values — confirm at the use sites.
pub(crate) const TOKEN: &str = "$surrealdb::private::sql::Strand";
13
14/// A string that doesn't contain NUL bytes.
15#[revisioned(revision = 1)]
16#[derive(Clone, Debug, Default, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize, Hash)]
17#[serde(rename = "$surrealdb::private::sql::Strand")]
18#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
19#[non_exhaustive]
20pub struct Strand(#[serde(with = "no_nul_bytes")] pub String);
21
22impl From<String> for Strand {
23	fn from(s: String) -> Self {
24		debug_assert!(!s.contains('\0'));
25		Strand(s)
26	}
27}
28
29impl From<&str> for Strand {
30	fn from(s: &str) -> Self {
31		debug_assert!(!s.contains('\0'));
32		Self::from(String::from(s))
33	}
34}
35
36impl Deref for Strand {
37	type Target = String;
38	fn deref(&self) -> &Self::Target {
39		&self.0
40	}
41}
42
43impl From<Strand> for String {
44	fn from(s: Strand) -> Self {
45		s.0
46	}
47}
48
49impl Strand {
50	/// Get the underlying String slice
51	pub fn as_str(&self) -> &str {
52		self.0.as_str()
53	}
54	/// Returns the underlying String
55	pub fn as_string(self) -> String {
56		self.0
57	}
58	/// Convert the Strand to a raw String
59	pub fn to_raw(self) -> String {
60		self.0
61	}
62}
63
64impl Display for Strand {
65	fn fmt(&self, f: &mut Formatter) -> fmt::Result {
66		Display::fmt(&quote_plain_str(&self.0), f)
67	}
68}
69
70impl ops::Add for Strand {
71	type Output = Self;
72	fn add(mut self, other: Self) -> Self {
73		self.0.push_str(other.as_str());
74		self
75	}
76}
77
78impl TryAdd for Strand {
79	type Output = Self;
80	fn try_add(mut self, other: Self) -> Result<Self, Error> {
81		if self.0.try_reserve(other.len()).is_ok() {
82			self.0.push_str(other.as_str());
83			Ok(self)
84		} else {
85			Err(Error::InsufficientReserve(format!(
86				"additional string of length {} bytes",
87				other.0.len()
88			)))
89		}
90	}
91}
92
93// serde(with = no_nul_bytes) will (de)serialize with no NUL bytes.
94pub(crate) mod no_nul_bytes {
95	use serde::{
96		de::{self, Visitor},
97		Deserializer, Serializer,
98	};
99	use std::fmt;
100
101	pub(crate) fn serialize<S>(s: &str, serializer: S) -> Result<S::Ok, S::Error>
102	where
103		S: Serializer,
104	{
105		debug_assert!(!s.contains('\0'));
106		serializer.serialize_str(s)
107	}
108
109	pub(crate) fn deserialize<'de, D>(deserializer: D) -> Result<String, D::Error>
110	where
111		D: Deserializer<'de>,
112	{
113		struct NoNulBytesVisitor;
114
115		impl Visitor<'_> for NoNulBytesVisitor {
116			type Value = String;
117
118			fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
119				formatter.write_str("a string without any NUL bytes")
120			}
121
122			fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
123			where
124				E: de::Error,
125			{
126				if value.contains('\0') {
127					Err(de::Error::custom("contained NUL byte"))
128				} else {
129					Ok(value.to_owned())
130				}
131			}
132
133			fn visit_string<E>(self, value: String) -> Result<Self::Value, E>
134			where
135				E: de::Error,
136			{
137				if value.contains('\0') {
138					Err(de::Error::custom("contained NUL byte"))
139				} else {
140					Ok(value)
141				}
142			}
143		}
144
145		deserializer.deserialize_string(NoNulBytesVisitor)
146	}
147}