// surrealdb_core/dbs/capabilities.rs

1use std::collections::HashSet;
2use std::fmt;
3use std::hash::Hash;
4use std::net::IpAddr;
5
6use crate::iam::{Auth, Level};
7use crate::rpc::Method;
8use ipnet::IpNet;
9use url::Url;
10
/// A capability target that can be tested against an element of type `Item`.
///
/// `Item` defaults to `Self`, so `T: Target` means a `T` can be matched against
/// another `T`; a type may also implement e.g. `Target<str>` to match raw names.
pub trait Target<Item: ?Sized = Self> {
	/// Returns `true` when `elem` is covered by this target.
	fn matches(&self, elem: &Item) -> bool;
}

/// A function allow/deny target.
///
/// Field `0` is the function family (the part before the first `::`). Field `1`
/// is the remainder of the function path, or `None` when the target covers the
/// whole family (parsed from `family::*` or a bare `family`).
#[derive(Debug, Clone, Hash, Eq, PartialEq)]
#[non_exhaustive]
pub struct FuncTarget(pub String, pub Option<String>);

impl fmt::Display for FuncTarget {
	fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
		match &self.1 {
			// Use the `::` path separator so the output parses back into the same
			// target; a single `:` would be read as part of the family name.
			Some(name) => write!(f, "{}::{}", self.0, name),
			None => write!(f, "{}::*", self.0),
		}
	}
}

impl Target for FuncTarget {
	/// A family-wide target matches any target of the same family; a named
	/// target only matches a target with the same family and the same name.
	fn matches(&self, elem: &FuncTarget) -> bool {
		match self {
			Self(family, Some(name)) => {
				family == &elem.0 && elem.1.as_ref().is_some_and(|n| n == name)
			}
			Self(family, None) => family == &elem.0,
		}
	}
}

impl Target<str> for FuncTarget {
	/// Matches a full function name such as `http::get`.
	///
	/// A named target requires exactly `family::name`; a family-wide target
	/// matches the bare family or any function inside it.
	fn matches(&self, elem: &str) -> bool {
		match &self.1 {
			Some(name) => {
				let Some((family, rest)) = elem.split_once("::") else {
					return false;
				};
				family == self.0 && rest == name
			}
			None => {
				let family = elem.split_once("::").map_or(elem, |(family, _)| family);
				family == self.0
			}
		}
	}
}

/// Error returned when a string is not a valid [`FuncTarget`].
#[derive(Debug, Clone)]
pub enum ParseFuncTargetError {
	/// A wildcard was applied to something other than the first path segment.
	InvalidWildcardFamily,
	/// The input was empty or contained characters not allowed in a name.
	InvalidName,
}

impl std::error::Error for ParseFuncTargetError {}
impl fmt::Display for ParseFuncTargetError {
	fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
		match *self {
			ParseFuncTargetError::InvalidName => {
				write!(f, "invalid function target name")
			}
			ParseFuncTargetError::InvalidWildcardFamily => {
				write!(
					f,
					"invalid function target wildcard family, only first part of function can be wildcarded"
				)
			}
		}
	}
}

impl std::str::FromStr for FuncTarget {
	type Err = ParseFuncTargetError;

	/// Parses `family`, `family::*` or `family::name[::more]`.
	///
	/// Surrounding whitespace is ignored. Path segments may contain ASCII
	/// letters, digits and underscores, so names like `array::find_index` are
	/// accepted. Only the family may be wildcarded, and it must not be empty.
	fn from_str(s: &str) -> Result<Self, Self::Err> {
		// Bytes allowed inside a single path segment.
		fn is_segment_byte(b: u8) -> bool {
			b.is_ascii_alphanumeric() || b == b'_'
		}

		let s = s.trim();

		if s.is_empty() {
			return Err(ParseFuncTargetError::InvalidName);
		}

		if let Some(family) = s.strip_suffix("::*") {
			if family.contains("::") {
				return Err(ParseFuncTargetError::InvalidWildcardFamily);
			}

			// `::*` on its own would produce a target for an empty family.
			if family.is_empty() || !family.bytes().all(is_segment_byte) {
				return Err(ParseFuncTargetError::InvalidName);
			}

			return Ok(FuncTarget(family.to_string(), None));
		}

		if !s.bytes().all(|b| is_segment_byte(b) || b == b':') {
			return Err(ParseFuncTargetError::InvalidName);
		}

		match s.split_once("::") {
			Some((family, rest)) => Ok(FuncTarget(family.to_string(), Some(rest.to_string()))),
			None => Ok(FuncTarget(s.to_string(), None)),
		}
	}
}
110
111#[derive(Debug, Clone, Hash, Eq, PartialEq)]
112#[non_exhaustive]
113pub enum ExperimentalTarget {
114	RecordReferences,
115	GraphQL,
116	BearerAccess,
117	DefineApi,
118}
119
120impl fmt::Display for ExperimentalTarget {
121	fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
122		match self {
123			Self::RecordReferences => write!(f, "record_references"),
124			Self::GraphQL => write!(f, "graphql"),
125			Self::BearerAccess => write!(f, "bearer_access"),
126			Self::DefineApi => write!(f, "define_api"),
127		}
128	}
129}
130
131impl Target for ExperimentalTarget {
132	fn matches(&self, elem: &ExperimentalTarget) -> bool {
133		self == elem
134	}
135}
136
137impl Target<str> for ExperimentalTarget {
138	fn matches(&self, elem: &str) -> bool {
139		match self {
140			Self::RecordReferences => elem.eq_ignore_ascii_case("record_references"),
141			Self::GraphQL => elem.eq_ignore_ascii_case("graphql"),
142			Self::BearerAccess => elem.eq_ignore_ascii_case("bearer_access"),
143			Self::DefineApi => elem.eq_ignore_ascii_case("define_api"),
144		}
145	}
146}
147
148#[derive(Debug, Clone)]
149pub enum ParseExperimentalTargetError {
150	InvalidName,
151}
152
153impl std::error::Error for ParseExperimentalTargetError {}
154impl fmt::Display for ParseExperimentalTargetError {
155	fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
156		match *self {
157			ParseExperimentalTargetError::InvalidName => {
158				write!(f, "invalid experimental target name")
159			}
160		}
161	}
162}
163
164impl std::str::FromStr for ExperimentalTarget {
165	type Err = ParseExperimentalTargetError;
166
167	fn from_str(s: &str) -> Result<Self, Self::Err> {
168		match_insensitive!(s.trim(), {
169			"record_references" => Ok(ExperimentalTarget::RecordReferences),
170			"graphql" => Ok(ExperimentalTarget::GraphQL),
171			"bearer_access" => Ok(ExperimentalTarget::BearerAccess),
172			"define_api" => Ok(ExperimentalTarget::DefineApi),
173			_ => Err(ParseExperimentalTargetError::InvalidName),
174		})
175	}
176}
177
178#[derive(Debug, Clone, Hash, Eq, PartialEq)]
179#[non_exhaustive]
180pub enum NetTarget {
181	Host(url::Host<String>, Option<u16>),
182	IPNet(ipnet::IpNet),
183}
184
185// impl display
186impl fmt::Display for NetTarget {
187	fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
188		match self {
189			Self::Host(host, Some(port)) => write!(f, "{}:{}", host, port),
190			Self::Host(host, None) => write!(f, "{}", host),
191			Self::IPNet(ipnet) => write!(f, "{}", ipnet),
192		}
193	}
194}
195
196impl Target for NetTarget {
197	fn matches(&self, elem: &Self) -> bool {
198		match self {
199			// If self contains a host and port, the elem must match both the host and port
200			Self::Host(host, Some(port)) => match elem {
201				Self::Host(_host, Some(_port)) => host == _host && port == _port,
202				_ => false,
203			},
204			// If self contains a host but no port, the elem must match the host only
205			Self::Host(host, None) => match elem {
206				Self::Host(_host, _) => host == _host,
207				_ => false,
208			},
209			// If self is an IPNet, it can match both an IPNet or a Host elem that contains an IPAddr
210			Self::IPNet(ipnet) => match elem {
211				Self::IPNet(_ipnet) => ipnet.contains(_ipnet),
212				Self::Host(host, _) => match host {
213					url::Host::Ipv4(ip) => ipnet.contains(&IpAddr::from(ip.to_owned())),
214					url::Host::Ipv6(ip) => ipnet.contains(&IpAddr::from(ip.to_owned())),
215					_ => false,
216				},
217			},
218		}
219	}
220}
221
222#[derive(Debug)]
223pub struct ParseNetTargetError;
224
225impl std::error::Error for ParseNetTargetError {}
226impl fmt::Display for ParseNetTargetError {
227	fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
228		write!(f, "The provided network target is not a valid host name, IP address or CIDR block")
229	}
230}
231
232impl std::str::FromStr for NetTarget {
233	type Err = ParseNetTargetError;
234
235	fn from_str(s: &str) -> Result<Self, Self::Err> {
236		// If it's a valid IPNet, return it
237		if let Ok(ipnet) = s.parse::<IpNet>() {
238			return Ok(NetTarget::IPNet(ipnet));
239		}
240
241		// If it's a valid IPAddr, return it as an IPNet
242		if let Ok(ipnet) = s.parse::<IpAddr>() {
243			return Ok(NetTarget::IPNet(IpNet::from(ipnet)));
244		}
245
246		// Parse the host and port parts from a string in the form of 'host' or 'host:port'
247		if let Ok(url) = Url::parse(format!("http://{s}").as_str()) {
248			if let Some(host) = url.host() {
249				// Url::parse will return port=None if the provided port was 80 (given we are using the http scheme). Get the original port from the string.
250				if let Some(Ok(port)) = s.split(':').next_back().map(|p| p.parse::<u16>()) {
251					return Ok(NetTarget::Host(host.to_owned(), Some(port)));
252				} else {
253					return Ok(NetTarget::Host(host.to_owned(), None));
254				}
255			}
256		}
257
258		Err(ParseNetTargetError)
259	}
260}
261
262#[derive(Debug, Clone, Hash, Eq, PartialEq)]
263pub struct MethodTarget {
264	pub method: Method,
265}
266
267// impl display
268impl fmt::Display for MethodTarget {
269	fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
270		write!(f, "{}", self.method.to_str())
271	}
272}
273
274impl Target for MethodTarget {
275	fn matches(&self, elem: &Self) -> bool {
276		self.method == elem.method
277	}
278}
279
280#[derive(Debug)]
281pub struct ParseMethodTargetError;
282
283impl std::error::Error for ParseMethodTargetError {}
284impl fmt::Display for ParseMethodTargetError {
285	fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
286		write!(f, "The provided method target is not a valid RPC method")
287	}
288}
289
290impl std::str::FromStr for MethodTarget {
291	type Err = ParseMethodTargetError;
292
293	fn from_str(s: &str) -> Result<Self, Self::Err> {
294		match Method::parse_case_insensitive(s) {
295			Method::Unknown => Err(ParseMethodTargetError),
296			method => Ok(MethodTarget {
297				method,
298			}),
299		}
300	}
301}
302
303#[derive(Debug, Clone, Hash, Eq, PartialEq)]
304#[non_exhaustive]
305pub enum RouteTarget {
306	Health,
307	Export,
308	Import,
309	Rpc,
310	Version,
311	Sync,
312	Sql,
313	Signin,
314	Signup,
315	Key,
316	Ml,
317	GraphQL,
318	Api,
319}
320
321// impl display
322impl fmt::Display for RouteTarget {
323	fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
324		match self {
325			RouteTarget::Health => write!(f, "health"),
326			RouteTarget::Export => write!(f, "export"),
327			RouteTarget::Import => write!(f, "import"),
328			RouteTarget::Rpc => write!(f, "rpc"),
329			RouteTarget::Version => write!(f, "version"),
330			RouteTarget::Sync => write!(f, "sync"),
331			RouteTarget::Sql => write!(f, "sql"),
332			RouteTarget::Signin => write!(f, "signin"),
333			RouteTarget::Signup => write!(f, "signup"),
334			RouteTarget::Key => write!(f, "key"),
335			RouteTarget::Ml => write!(f, "ml"),
336			RouteTarget::GraphQL => write!(f, "graphql"),
337			RouteTarget::Api => write!(f, "api"),
338		}
339	}
340}
341
342impl Target for RouteTarget {
343	fn matches(&self, elem: &Self) -> bool {
344		*self == *elem
345	}
346}
347
348#[derive(Debug)]
349pub struct ParseRouteTargetError;
350
351impl std::error::Error for ParseRouteTargetError {}
352impl fmt::Display for ParseRouteTargetError {
353	fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
354		write!(f, "The provided route target is not a valid HTTP route")
355	}
356}
357
358impl std::str::FromStr for RouteTarget {
359	type Err = ParseRouteTargetError;
360
361	fn from_str(s: &str) -> Result<Self, Self::Err> {
362		match s {
363			"health" => Ok(RouteTarget::Health),
364			"export" => Ok(RouteTarget::Export),
365			"import" => Ok(RouteTarget::Import),
366			"rpc" => Ok(RouteTarget::Rpc),
367			"version" => Ok(RouteTarget::Version),
368			"sync" => Ok(RouteTarget::Sync),
369			"sql" => Ok(RouteTarget::Sql),
370			"signin" => Ok(RouteTarget::Signin),
371			"signup" => Ok(RouteTarget::Signup),
372			"key" => Ok(RouteTarget::Key),
373			"ml" => Ok(RouteTarget::Ml),
374			"graphql" => Ok(RouteTarget::GraphQL),
375			"api" => Ok(RouteTarget::Api),
376			_ => Err(ParseRouteTargetError),
377		}
378	}
379}
380
381#[derive(Debug, Clone, Hash, Eq, PartialEq)]
382#[non_exhaustive]
383pub enum ArbitraryQueryTarget {
384	Guest,
385	Record,
386	System,
387}
388
389impl fmt::Display for ArbitraryQueryTarget {
390	fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
391		match self {
392			Self::Guest => write!(f, "guest"),
393			Self::Record => write!(f, "record"),
394			Self::System => write!(f, "system"),
395		}
396	}
397}
398
399impl<'a> From<&'a Level> for ArbitraryQueryTarget {
400	fn from(level: &'a Level) -> Self {
401		match level {
402			Level::No => ArbitraryQueryTarget::Guest,
403			Level::Root => ArbitraryQueryTarget::System,
404			Level::Namespace(_) => ArbitraryQueryTarget::System,
405			Level::Database(_, _) => ArbitraryQueryTarget::System,
406			Level::Record(_, _, _) => ArbitraryQueryTarget::Record,
407		}
408	}
409}
410
411impl<'a> From<&'a Auth> for ArbitraryQueryTarget {
412	fn from(auth: &'a Auth) -> Self {
413		auth.level().into()
414	}
415}
416
417impl Target for ArbitraryQueryTarget {
418	fn matches(&self, elem: &ArbitraryQueryTarget) -> bool {
419		self == elem
420	}
421}
422
423impl Target<str> for ArbitraryQueryTarget {
424	fn matches(&self, elem: &str) -> bool {
425		match self {
426			Self::Guest => elem.eq_ignore_ascii_case("guest"),
427			Self::Record => elem.eq_ignore_ascii_case("record"),
428			Self::System => elem.eq_ignore_ascii_case("system"),
429		}
430	}
431}
432
433#[derive(Debug, Clone)]
434pub enum ParseArbitraryQueryTargetError {
435	InvalidName,
436}
437
438impl std::error::Error for ParseArbitraryQueryTargetError {}
439impl fmt::Display for ParseArbitraryQueryTargetError {
440	fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
441		match *self {
442			ParseArbitraryQueryTargetError::InvalidName => {
443				write!(f, "invalid query target name")
444			}
445		}
446	}
447}
448
449impl std::str::FromStr for ArbitraryQueryTarget {
450	type Err = ParseArbitraryQueryTargetError;
451
452	fn from_str(s: &str) -> Result<Self, Self::Err> {
453		match_insensitive!(s.trim(), {
454			"guest" => Ok(ArbitraryQueryTarget::Guest),
455			"record" => Ok(ArbitraryQueryTarget::Record),
456			"system" => Ok(ArbitraryQueryTarget::System),
457			_ => Err(ParseArbitraryQueryTargetError::InvalidName),
458		})
459	}
460}
461
462#[derive(Debug, Clone, Eq, PartialEq)]
463#[non_exhaustive]
464pub enum Targets<T: Hash + Eq + PartialEq> {
465	None,
466	Some(HashSet<T>),
467	All,
468}
469
470impl<T: Target + Hash + Eq + PartialEq> From<T> for Targets<T> {
471	fn from(t: T) -> Self {
472		let mut set = HashSet::new();
473		set.insert(t);
474		Self::Some(set)
475	}
476}
477
478impl<T: Hash + Eq + PartialEq + fmt::Debug + fmt::Display> Targets<T> {
479	pub(crate) fn matches<S>(&self, elem: &S) -> bool
480	where
481		S: ?Sized,
482		T: Target<S>,
483	{
484		match self {
485			Self::None => false,
486			Self::All => true,
487			Self::Some(targets) => targets.iter().any(|t| t.matches(elem)),
488		}
489	}
490}
491
492impl<T: Target + Hash + Eq + PartialEq + fmt::Display> fmt::Display for Targets<T> {
493	fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
494		match self {
495			Self::None => write!(f, "none"),
496			Self::All => write!(f, "all"),
497			Self::Some(targets) => {
498				let targets =
499					targets.iter().map(|t| t.to_string()).collect::<Vec<String>>().join(", ");
500				write!(f, "{}", targets)
501			}
502		}
503	}
504}
505
506#[derive(Debug, Clone)]
507#[non_exhaustive]
508pub struct Capabilities {
509	scripting: bool,
510	guest_access: bool,
511	live_query_notifications: bool,
512
513	allow_funcs: Targets<FuncTarget>,
514	deny_funcs: Targets<FuncTarget>,
515	allow_net: Targets<NetTarget>,
516	deny_net: Targets<NetTarget>,
517	allow_rpc: Targets<MethodTarget>,
518	deny_rpc: Targets<MethodTarget>,
519	allow_http: Targets<RouteTarget>,
520	deny_http: Targets<RouteTarget>,
521	allow_experimental: Targets<ExperimentalTarget>,
522	deny_experimental: Targets<ExperimentalTarget>,
523	allow_arbitrary_query: Targets<ArbitraryQueryTarget>,
524	deny_arbitrary_query: Targets<ArbitraryQueryTarget>,
525}
526
527impl fmt::Display for Capabilities {
528	fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
529		write!(
530            f,
531            "scripting={}, guest_access={}, live_query_notifications={}, allow_funcs={}, deny_funcs={}, allow_net={}, deny_net={}, allow_rpc={}, deny_rpc={}, allow_http={}, deny_http={}, allow_experimental={}, deny_experimental={}, allow_arbitrary_query={}, deny_arbitrary_query={}",
532            self.scripting, self.guest_access, self.live_query_notifications, self.allow_funcs, self.deny_funcs, self.allow_net, self.deny_net, self.allow_rpc, self.deny_rpc, self.allow_http, self.deny_http, self.allow_experimental, self.deny_experimental, self.allow_arbitrary_query, self.deny_arbitrary_query,
533        )
534	}
535}
536
537impl Default for Capabilities {
538	fn default() -> Self {
539		Self {
540			scripting: false,
541			guest_access: false,
542			live_query_notifications: true,
543
544			allow_funcs: Targets::All,
545			deny_funcs: Targets::None,
546			allow_net: Targets::None,
547			deny_net: Targets::None,
548			allow_rpc: Targets::All,
549			deny_rpc: Targets::None,
550			allow_http: Targets::All,
551			deny_http: Targets::None,
552			allow_experimental: Targets::None,
553			deny_experimental: Targets::None,
554			allow_arbitrary_query: Targets::All,
555			deny_arbitrary_query: Targets::None,
556		}
557	}
558}
559
560impl Capabilities {
561	pub fn all() -> Self {
562		Self {
563			scripting: true,
564			guest_access: true,
565			live_query_notifications: true,
566
567			allow_funcs: Targets::All,
568			deny_funcs: Targets::None,
569			allow_net: Targets::All,
570			deny_net: Targets::None,
571			allow_rpc: Targets::All,
572			deny_rpc: Targets::None,
573			allow_http: Targets::All,
574			deny_http: Targets::None,
575			allow_experimental: Targets::None,
576			deny_experimental: Targets::None,
577			allow_arbitrary_query: Targets::All,
578			deny_arbitrary_query: Targets::None,
579		}
580	}
581
582	pub fn none() -> Self {
583		Self {
584			scripting: false,
585			guest_access: false,
586			live_query_notifications: false,
587
588			allow_funcs: Targets::None,
589			deny_funcs: Targets::None,
590			allow_net: Targets::None,
591			deny_net: Targets::None,
592			allow_rpc: Targets::None,
593			deny_rpc: Targets::None,
594			allow_http: Targets::None,
595			deny_http: Targets::None,
596			allow_experimental: Targets::None,
597			deny_experimental: Targets::None,
598			allow_arbitrary_query: Targets::None,
599			deny_arbitrary_query: Targets::None,
600		}
601	}
602
603	pub fn with_scripting(mut self, scripting: bool) -> Self {
604		self.scripting = scripting;
605		self
606	}
607
608	pub fn with_guest_access(mut self, guest_access: bool) -> Self {
609		self.guest_access = guest_access;
610		self
611	}
612
613	pub fn with_live_query_notifications(mut self, live_query_notifications: bool) -> Self {
614		self.live_query_notifications = live_query_notifications;
615		self
616	}
617
618	pub fn with_functions(mut self, allow_funcs: Targets<FuncTarget>) -> Self {
619		self.allow_funcs = allow_funcs;
620		self
621	}
622
623	pub fn allowed_functions_mut(&mut self) -> &mut Targets<FuncTarget> {
624		&mut self.allow_funcs
625	}
626
627	pub fn without_functions(mut self, deny_funcs: Targets<FuncTarget>) -> Self {
628		self.deny_funcs = deny_funcs;
629		self
630	}
631
632	pub fn denied_functions_mut(&mut self) -> &mut Targets<FuncTarget> {
633		&mut self.deny_funcs
634	}
635
636	pub fn with_experimental(mut self, allow_experimental: Targets<ExperimentalTarget>) -> Self {
637		self.allow_experimental = allow_experimental;
638		self
639	}
640
641	pub fn allowed_experimental_features_mut(&mut self) -> &mut Targets<ExperimentalTarget> {
642		&mut self.allow_experimental
643	}
644
645	pub fn without_experimental(mut self, deny_experimental: Targets<ExperimentalTarget>) -> Self {
646		self.deny_experimental = deny_experimental;
647		self
648	}
649
650	pub fn denied_experimental_features_mut(&mut self) -> &mut Targets<ExperimentalTarget> {
651		&mut self.deny_experimental
652	}
653
654	pub fn with_arbitrary_query(
655		mut self,
656		allow_arbitrary_query: Targets<ArbitraryQueryTarget>,
657	) -> Self {
658		self.allow_arbitrary_query = allow_arbitrary_query;
659		self
660	}
661
662	pub fn without_arbitrary_query(
663		mut self,
664		deny_arbitrary_query: Targets<ArbitraryQueryTarget>,
665	) -> Self {
666		self.deny_arbitrary_query = deny_arbitrary_query;
667		self
668	}
669
670	pub fn with_network_targets(mut self, allow_net: Targets<NetTarget>) -> Self {
671		self.allow_net = allow_net;
672		self
673	}
674
675	pub fn allowed_network_targets_mut(&mut self) -> &mut Targets<NetTarget> {
676		&mut self.allow_net
677	}
678
679	pub fn without_network_targets(mut self, deny_net: Targets<NetTarget>) -> Self {
680		self.deny_net = deny_net;
681		self
682	}
683
684	pub fn denied_network_targets_mut(&mut self) -> &mut Targets<NetTarget> {
685		&mut self.deny_net
686	}
687
688	pub fn with_rpc_methods(mut self, allow_rpc: Targets<MethodTarget>) -> Self {
689		self.allow_rpc = allow_rpc;
690		self
691	}
692
693	pub fn without_rpc_methods(mut self, deny_rpc: Targets<MethodTarget>) -> Self {
694		self.deny_rpc = deny_rpc;
695		self
696	}
697
698	pub fn with_http_routes(mut self, allow_http: Targets<RouteTarget>) -> Self {
699		self.allow_http = allow_http;
700		self
701	}
702
703	pub fn without_http_routes(mut self, deny_http: Targets<RouteTarget>) -> Self {
704		self.deny_http = deny_http;
705		self
706	}
707
708	pub fn allows_scripting(&self) -> bool {
709		self.scripting
710	}
711
712	pub fn allows_guest_access(&self) -> bool {
713		self.guest_access
714	}
715
716	pub fn allows_live_query_notifications(&self) -> bool {
717		self.live_query_notifications
718	}
719
720	pub fn allows_function_name(&self, target: &str) -> bool {
721		self.allow_funcs.matches(target) && !self.deny_funcs.matches(target)
722	}
723
724	pub fn allows_experimental(&self, target: &ExperimentalTarget) -> bool {
725		self.allow_experimental.matches(target) && !self.deny_experimental.matches(target)
726	}
727
728	pub fn allows_experimental_name(&self, target: &str) -> bool {
729		self.allow_experimental.matches(target) && !self.deny_experimental.matches(target)
730	}
731
732	pub fn allows_query(&self, target: &ArbitraryQueryTarget) -> bool {
733		self.allow_arbitrary_query.matches(target) && !self.deny_arbitrary_query.matches(target)
734	}
735
736	pub fn allows_query_name(&self, target: &str) -> bool {
737		self.allow_arbitrary_query.matches(target) && !self.deny_arbitrary_query.matches(target)
738	}
739
740	pub fn allows_network_target(&self, target: &NetTarget) -> bool {
741		self.allow_net.matches(target) && !self.deny_net.matches(target)
742	}
743
744	pub fn allows_rpc_method(&self, target: &MethodTarget) -> bool {
745		self.allow_rpc.matches(target) && !self.deny_rpc.matches(target)
746	}
747
748	pub fn allows_http_route(&self, target: &RouteTarget) -> bool {
749		self.allow_http.matches(target) && !self.deny_http.matches(target)
750	}
751}
752
753#[cfg(test)]
754mod tests {
755	use std::str::FromStr;
756	use test_log::test;
757
758	use super::*;
759
760	#[test]
761	fn test_invalid_func_target() {
762		FuncTarget::from_str("te::*st").unwrap_err();
763		FuncTarget::from_str("\0::st").unwrap_err();
764		FuncTarget::from_str("").unwrap_err();
765		FuncTarget::from_str("❤️").unwrap_err();
766	}
767
768	#[test]
769	fn test_func_target() {
770		assert!(FuncTarget::from_str("test").unwrap().matches("test"));
771		assert!(!FuncTarget::from_str("test").unwrap().matches("test2"));
772
773		assert!(!FuncTarget::from_str("test::").unwrap().matches("test"));
774
775		assert!(FuncTarget::from_str("test::*").unwrap().matches("test::name"));
776		assert!(!FuncTarget::from_str("test::*").unwrap().matches("test2::name"));
777
778		assert!(FuncTarget::from_str("test::name").unwrap().matches("test::name"));
779		assert!(!FuncTarget::from_str("test::name").unwrap().matches("test::name2"));
780	}
781
782	#[test]
783	fn test_net_target() {
784		// IPNet IPv4
785		assert!(NetTarget::from_str("10.0.0.0/8")
786			.unwrap()
787			.matches(&NetTarget::from_str("10.0.1.0/24").unwrap()));
788		assert!(NetTarget::from_str("10.0.0.0/8")
789			.unwrap()
790			.matches(&NetTarget::from_str("10.0.1.2").unwrap()));
791		assert!(!NetTarget::from_str("10.0.0.0/8")
792			.unwrap()
793			.matches(&NetTarget::from_str("20.0.1.0/24").unwrap()));
794		assert!(!NetTarget::from_str("10.0.0.0/8")
795			.unwrap()
796			.matches(&NetTarget::from_str("20.0.1.0").unwrap()));
797
798		// IPNet IPv6
799		assert!(NetTarget::from_str("2001:db8::1")
800			.unwrap()
801			.matches(&NetTarget::from_str("2001:db8::1").unwrap()));
802		assert!(NetTarget::from_str("2001:db8::/32")
803			.unwrap()
804			.matches(&NetTarget::from_str("2001:db8::1").unwrap()));
805		assert!(NetTarget::from_str("2001:db8::/32")
806			.unwrap()
807			.matches(&NetTarget::from_str("2001:db8:abcd:12::/64").unwrap()));
808		assert!(!NetTarget::from_str("2001:db8::/32")
809			.unwrap()
810			.matches(&NetTarget::from_str("2001:db9::1").unwrap()));
811		assert!(!NetTarget::from_str("2001:db8::/32")
812			.unwrap()
813			.matches(&NetTarget::from_str("2001:db9:abcd:12::1/64").unwrap()));
814
815		// Host domain with and without port
816		assert!(NetTarget::from_str("example.com")
817			.unwrap()
818			.matches(&NetTarget::from_str("example.com").unwrap()));
819		assert!(NetTarget::from_str("example.com")
820			.unwrap()
821			.matches(&NetTarget::from_str("example.com:80").unwrap()));
822		assert!(!NetTarget::from_str("example.com")
823			.unwrap()
824			.matches(&NetTarget::from_str("www.example.com").unwrap()));
825		assert!(!NetTarget::from_str("example.com")
826			.unwrap()
827			.matches(&NetTarget::from_str("www.example.com:80").unwrap()));
828		assert!(NetTarget::from_str("example.com:80")
829			.unwrap()
830			.matches(&NetTarget::from_str("example.com:80").unwrap()));
831		assert!(!NetTarget::from_str("example.com:80")
832			.unwrap()
833			.matches(&NetTarget::from_str("example.com:443").unwrap()));
834		assert!(!NetTarget::from_str("example.com:80")
835			.unwrap()
836			.matches(&NetTarget::from_str("example.com").unwrap()));
837
838		// Host IPv4 with and without port
839		assert!(
840			NetTarget::from_str("127.0.0.1")
841				.unwrap()
842				.matches(&NetTarget::from_str("127.0.0.1").unwrap()),
843			"Host IPv4 without port matches itself"
844		);
845		assert!(
846			NetTarget::from_str("127.0.0.1")
847				.unwrap()
848				.matches(&NetTarget::from_str("127.0.0.1:80").unwrap()),
849			"Host IPv4 without port matches Host IPv4 with port"
850		);
851		assert!(
852			NetTarget::from_str("10.0.0.0/8")
853				.unwrap()
854				.matches(&NetTarget::from_str("10.0.0.1:80").unwrap()),
855			"IPv4 network matches Host IPv4 with port"
856		);
857		assert!(
858			NetTarget::from_str("127.0.0.1:80")
859				.unwrap()
860				.matches(&NetTarget::from_str("127.0.0.1:80").unwrap()),
861			"Host IPv4 with port matches itself"
862		);
863		assert!(
864			!NetTarget::from_str("127.0.0.1:80")
865				.unwrap()
866				.matches(&NetTarget::from_str("127.0.0.1").unwrap()),
867			"Host IPv4 with port does not match Host IPv4 without port"
868		);
869		assert!(
870			!NetTarget::from_str("127.0.0.1:80")
871				.unwrap()
872				.matches(&NetTarget::from_str("127.0.0.1:443").unwrap()),
873			"Host IPv4 with port does not match Host IPv4 with different port"
874		);
875
876		// Host IPv6 with and without port
877		assert!(
878			NetTarget::from_str("[2001:db8::1]")
879				.unwrap()
880				.matches(&NetTarget::from_str("[2001:db8::1]").unwrap()),
881			"Host IPv6 without port matches itself"
882		);
883		assert!(
884			NetTarget::from_str("[2001:db8::1]")
885				.unwrap()
886				.matches(&NetTarget::from_str("[2001:db8::1]:80").unwrap()),
887			"Host IPv6 without port matches Host IPv6 with port"
888		);
889		assert!(
890			NetTarget::from_str("2001:db8::1")
891				.unwrap()
892				.matches(&NetTarget::from_str("[2001:db8::1]:80").unwrap()),
893			"IPv6 addr matches Host IPv6 with port"
894		);
895		assert!(
896			NetTarget::from_str("2001:db8::/64")
897				.unwrap()
898				.matches(&NetTarget::from_str("[2001:db8::1]:80").unwrap()),
899			"IPv6 network matches Host IPv6 with port"
900		);
901		assert!(
902			NetTarget::from_str("[2001:db8::1]:80")
903				.unwrap()
904				.matches(&NetTarget::from_str("[2001:db8::1]:80").unwrap()),
905			"Host IPv6 with port matches itself"
906		);
907		assert!(
908			!NetTarget::from_str("[2001:db8::1]:80")
909				.unwrap()
910				.matches(&NetTarget::from_str("[2001:db8::1]").unwrap()),
911			"Host IPv6 with port does not match Host IPv6 without port"
912		);
913		assert!(
914			!NetTarget::from_str("[2001:db8::1]:80")
915				.unwrap()
916				.matches(&NetTarget::from_str("[2001:db8::1]:443").unwrap()),
917			"Host IPv6 with port does not match Host IPv6 with different port"
918		);
919
920		// Test invalid targets
921		assert!(NetTarget::from_str("exam^ple.com").is_err());
922		assert!(NetTarget::from_str("example.com:80:80").is_err());
923		assert!(NetTarget::from_str("11111.3.4.5").is_err());
924		assert!(NetTarget::from_str("2001:db8::1/129").is_err());
925		assert!(NetTarget::from_str("[2001:db8::1").is_err());
926	}
927
928	#[test]
929	fn test_method_target() {
930		assert!(MethodTarget::from_str("query")
931			.unwrap()
932			.matches(&MethodTarget::from_str("query").unwrap()));
933		assert!(MethodTarget::from_str("query")
934			.unwrap()
935			.matches(&MethodTarget::from_str("Query").unwrap()));
936		assert!(MethodTarget::from_str("query")
937			.unwrap()
938			.matches(&MethodTarget::from_str("QUERY").unwrap()));
939		assert!(!MethodTarget::from_str("query")
940			.unwrap()
941			.matches(&MethodTarget::from_str("ping").unwrap()));
942	}
943
944	#[test]
945	fn test_targets() {
946		assert!(Targets::<NetTarget>::All.matches(&NetTarget::from_str("example.com").unwrap()));
947		assert!(Targets::<FuncTarget>::All.matches("http::get"));
948		assert!(!Targets::<NetTarget>::None.matches(&NetTarget::from_str("example.com").unwrap()));
949		assert!(!Targets::<FuncTarget>::None.matches("http::get"));
950		assert!(Targets::<NetTarget>::Some([NetTarget::from_str("example.com").unwrap()].into())
951			.matches(&NetTarget::from_str("example.com").unwrap()));
952		assert!(!Targets::<NetTarget>::Some([NetTarget::from_str("example.com").unwrap()].into())
953			.matches(&NetTarget::from_str("www.example.com").unwrap()));
954		assert!(Targets::<FuncTarget>::Some([FuncTarget::from_str("http::get").unwrap()].into())
955			.matches("http::get"));
956		assert!(!Targets::<FuncTarget>::Some([FuncTarget::from_str("http::get").unwrap()].into())
957			.matches("http::post"));
958	}
959
960	#[test]
961	fn test_capabilities() {
962		// When scripting is allowed
963		{
964			let caps = Capabilities::default().with_scripting(true);
965			assert!(caps.allows_scripting());
966		}
967
968		// When scripting is denied
969		{
970			let caps = Capabilities::default().with_scripting(false);
971			assert!(!caps.allows_scripting());
972		}
973
974		// When guest access is allowed
975		{
976			let caps = Capabilities::default().with_guest_access(true);
977			assert!(caps.allows_guest_access());
978		}
979
980		// When guest access is denied
981		{
982			let caps = Capabilities::default().with_guest_access(false);
983			assert!(!caps.allows_guest_access());
984		}
985
986		// When live query notifications are allowed
987		{
988			let cap = Capabilities::default().with_live_query_notifications(true);
989			assert!(cap.allows_live_query_notifications());
990		}
991
992		// When live query notifications are disabled
993		{
994			let cap = Capabilities::default().with_live_query_notifications(false);
995			assert!(!cap.allows_live_query_notifications());
996		}
997
998		// When all nets are allowed
999		{
1000			let caps = Capabilities::default()
1001				.with_network_targets(Targets::<NetTarget>::All)
1002				.without_network_targets(Targets::<NetTarget>::None);
1003			assert!(caps.allows_network_target(&NetTarget::from_str("example.com").unwrap()));
1004			assert!(caps.allows_network_target(&NetTarget::from_str("example.com:80").unwrap()));
1005		}
1006
1007		// When all nets are allowed and denied at the same time
1008		{
1009			let caps = Capabilities::default()
1010				.with_network_targets(Targets::<NetTarget>::All)
1011				.without_network_targets(Targets::<NetTarget>::All);
1012			assert!(!caps.allows_network_target(&NetTarget::from_str("example.com").unwrap()));
1013			assert!(!caps.allows_network_target(&NetTarget::from_str("example.com:80").unwrap()));
1014		}
1015
1016		// When some nets are allowed and some are denied, deny overrides the allow rules
1017		{
1018			let caps = Capabilities::default()
1019				.with_network_targets(Targets::<NetTarget>::Some(
1020					[NetTarget::from_str("example.com").unwrap()].into(),
1021				))
1022				.without_network_targets(Targets::<NetTarget>::Some(
1023					[NetTarget::from_str("example.com:80").unwrap()].into(),
1024				));
1025			assert!(caps.allows_network_target(&NetTarget::from_str("example.com").unwrap()));
1026			assert!(caps.allows_network_target(&NetTarget::from_str("example.com:443").unwrap()));
1027			assert!(!caps.allows_network_target(&NetTarget::from_str("example.com:80").unwrap()));
1028		}
1029
1030		// When all funcs are allowed
1031		{
1032			let caps = Capabilities::default()
1033				.with_functions(Targets::<FuncTarget>::All)
1034				.without_functions(Targets::<FuncTarget>::None);
1035			assert!(caps.allows_function_name("http::get"));
1036			assert!(caps.allows_function_name("http::post"));
1037		}
1038
1039		// When all funcs are allowed and denied at the same time
1040		{
1041			let caps = Capabilities::default()
1042				.with_functions(Targets::<FuncTarget>::All)
1043				.without_functions(Targets::<FuncTarget>::All);
1044			assert!(!caps.allows_function_name("http::get"));
1045			assert!(!caps.allows_function_name("http::post"));
1046		}
1047
1048		// When some funcs are allowed and some are denied, deny overrides the allow rules
1049		{
1050			let caps = Capabilities::default()
1051				.with_functions(Targets::<FuncTarget>::Some(
1052					[FuncTarget::from_str("http::*").unwrap()].into(),
1053				))
1054				.without_functions(Targets::<FuncTarget>::Some(
1055					[FuncTarget::from_str("http::post").unwrap()].into(),
1056				));
1057			assert!(caps.allows_function_name("http::get"));
1058			assert!(caps.allows_function_name("http::put"));
1059			assert!(!caps.allows_function_name("http::post"));
1060		}
1061
1062		// When all RPC methods are allowed
1063		{
1064			let caps = Capabilities::default()
1065				.with_rpc_methods(Targets::<MethodTarget>::All)
1066				.without_rpc_methods(Targets::<MethodTarget>::None);
1067			assert!(caps.allows_rpc_method(&MethodTarget::from_str("ping").unwrap()));
1068			assert!(caps.allows_rpc_method(&MethodTarget::from_str("select").unwrap()));
1069			assert!(caps.allows_rpc_method(&MethodTarget::from_str("query").unwrap()));
1070		}
1071
1072		// When all RPC methods are allowed and denied at the same time
1073		{
1074			let caps = Capabilities::default()
1075				.with_rpc_methods(Targets::<MethodTarget>::All)
1076				.without_rpc_methods(Targets::<MethodTarget>::All);
1077			assert!(!caps.allows_rpc_method(&MethodTarget::from_str("ping").unwrap()));
1078			assert!(!caps.allows_rpc_method(&MethodTarget::from_str("select").unwrap()));
1079			assert!(!caps.allows_rpc_method(&MethodTarget::from_str("query").unwrap()));
1080		}
1081
1082		// When some RPC methods are allowed and some are denied, deny overrides the allow rules
1083		{
1084			let caps = Capabilities::default()
1085				.with_rpc_methods(Targets::<MethodTarget>::Some(
1086					[
1087						MethodTarget::from_str("select").unwrap(),
1088						MethodTarget::from_str("create").unwrap(),
1089						MethodTarget::from_str("insert").unwrap(),
1090						MethodTarget::from_str("update").unwrap(),
1091						MethodTarget::from_str("query").unwrap(),
1092						MethodTarget::from_str("run").unwrap(),
1093					]
1094					.into(),
1095				))
1096				.without_rpc_methods(Targets::<MethodTarget>::Some(
1097					[
1098						MethodTarget::from_str("query").unwrap(),
1099						MethodTarget::from_str("run").unwrap(),
1100					]
1101					.into(),
1102				));
1103
1104			assert!(caps.allows_rpc_method(&MethodTarget::from_str("select").unwrap()));
1105			assert!(caps.allows_rpc_method(&MethodTarget::from_str("create").unwrap()));
1106			assert!(caps.allows_rpc_method(&MethodTarget::from_str("insert").unwrap()));
1107			assert!(caps.allows_rpc_method(&MethodTarget::from_str("update").unwrap()));
1108			assert!(!caps.allows_rpc_method(&MethodTarget::from_str("query").unwrap()));
1109			assert!(!caps.allows_rpc_method(&MethodTarget::from_str("run").unwrap()));
1110		}
1111
1112		// When all HTTP routes are allowed
1113		{
1114			let caps = Capabilities::default()
1115				.with_http_routes(Targets::<RouteTarget>::All)
1116				.without_http_routes(Targets::<RouteTarget>::None);
1117			assert!(caps.allows_http_route(&RouteTarget::from_str("version").unwrap()));
1118			assert!(caps.allows_http_route(&RouteTarget::from_str("export").unwrap()));
1119			assert!(caps.allows_http_route(&RouteTarget::from_str("sql").unwrap()));
1120		}
1121
1122		// When all HTTP routes are allowed and denied at the same time
1123		{
1124			let caps = Capabilities::default()
1125				.with_http_routes(Targets::<RouteTarget>::All)
1126				.without_http_routes(Targets::<RouteTarget>::All);
1127			assert!(!caps.allows_http_route(&RouteTarget::from_str("version").unwrap()));
1128			assert!(!caps.allows_http_route(&RouteTarget::from_str("export").unwrap()));
1129			assert!(!caps.allows_http_route(&RouteTarget::from_str("sql").unwrap()));
1130		}
1131
		// When some HTTP routes are allowed and some are denied, deny overrides the allow rules
1133		{
1134			let caps = Capabilities::default()
1135				.with_http_routes(Targets::<RouteTarget>::Some(
1136					[
1137						RouteTarget::from_str("version").unwrap(),
1138						RouteTarget::from_str("import").unwrap(),
1139						RouteTarget::from_str("export").unwrap(),
1140						RouteTarget::from_str("key").unwrap(),
1141						RouteTarget::from_str("sql").unwrap(),
1142						RouteTarget::from_str("rpc").unwrap(),
1143					]
1144					.into(),
1145				))
1146				.without_http_routes(Targets::<RouteTarget>::Some(
1147					[RouteTarget::from_str("sql").unwrap(), RouteTarget::from_str("rpc").unwrap()]
1148						.into(),
1149				));
1150
1151			assert!(caps.allows_http_route(&RouteTarget::from_str("version").unwrap()));
1152			assert!(caps.allows_http_route(&RouteTarget::from_str("import").unwrap()));
1153			assert!(caps.allows_http_route(&RouteTarget::from_str("export").unwrap()));
1154			assert!(caps.allows_http_route(&RouteTarget::from_str("key").unwrap()));
1155			assert!(!caps.allows_http_route(&RouteTarget::from_str("sql").unwrap()));
1156			assert!(!caps.allows_http_route(&RouteTarget::from_str("rpc").unwrap()));
1157		}
1158	}
1159}