1use std::collections::HashSet;
2use std::fmt;
3use std::hash::Hash;
4use std::net::IpAddr;
5
6use crate::iam::{Auth, Level};
7use crate::rpc::Method;
8use ipnet::IpNet;
9use url::Url;
10
/// A capability target that decides whether a requested element falls within
/// its scope.
///
/// `Elem` is the type being checked. It defaults to the implementing type
/// itself and may be unsized (for example `str`), so a target can be compared
/// either against another target or against a raw name.
pub trait Target<Elem: ?Sized = Self> {
    /// Returns `true` when `element` is covered by this target.
    fn matches(&self, element: &Elem) -> bool;
}
14
15#[derive(Debug, Clone, Hash, Eq, PartialEq)]
16#[non_exhaustive]
17pub struct FuncTarget(pub String, pub Option<String>);
18
19impl fmt::Display for FuncTarget {
20 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
21 match &self.1 {
22 Some(name) => write!(f, "{}:{}", self.0, name),
23 None => write!(f, "{}::*", self.0),
24 }
25 }
26}
27
28impl Target for FuncTarget {
29 fn matches(&self, elem: &FuncTarget) -> bool {
30 match self {
31 Self(family, Some(name)) => {
32 family == &elem.0 && (elem.1.as_ref().is_some_and(|n| n == name))
33 }
34 Self(family, None) => family == &elem.0,
35 }
36 }
37}
38
39impl Target<str> for FuncTarget {
40 fn matches(&self, elem: &str) -> bool {
41 if let Some(x) = self.1.as_ref() {
42 let Some((f, r)) = elem.split_once("::") else {
43 return false;
44 };
45
46 f == self.0 && r == x
47 } else {
48 let f = elem.split_once("::").map(|(f, _)| f).unwrap_or(elem);
49 f == self.0
50 }
51 }
52}
53
/// Error returned when a string cannot be parsed into a `FuncTarget`.
#[derive(Debug, Clone)]
pub enum ParseFuncTargetError {
    /// A `::*` wildcard followed something other than a single family segment.
    InvalidWildcardFamily,
    /// The target was empty or contained characters that are not allowed.
    InvalidName,
}

impl std::error::Error for ParseFuncTargetError {}

impl fmt::Display for ParseFuncTargetError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match self {
            Self::InvalidName => "invalid function target name",
            Self::InvalidWildcardFamily => {
                "invalid function target wildcard family, only first part of function can be wildcarded"
            }
        };
        f.write_str(message)
    }
}
76
77impl std::str::FromStr for FuncTarget {
78 type Err = ParseFuncTargetError;
79
80 fn from_str(s: &str) -> Result<Self, Self::Err> {
81 let s = s.trim();
82
83 if s.is_empty() {
84 return Err(ParseFuncTargetError::InvalidName);
85 }
86
87 if let Some(family) = s.strip_suffix("::*") {
88 if family.contains("::") {
89 return Err(ParseFuncTargetError::InvalidWildcardFamily);
90 }
91
92 if !family.bytes().all(|x| x.is_ascii_alphanumeric()) {
93 return Err(ParseFuncTargetError::InvalidName);
94 }
95
96 return Ok(FuncTarget(family.to_string(), None));
97 }
98
99 if !s.bytes().all(|x| x.is_ascii_alphanumeric() || x == b':') {
100 return Err(ParseFuncTargetError::InvalidName);
101 }
102
103 if let Some((first, rest)) = s.split_once("::") {
104 Ok(FuncTarget(first.to_string(), Some(rest.to_string())))
105 } else {
106 Ok(FuncTarget(s.to_string(), None))
107 }
108 }
109}
110
111#[derive(Debug, Clone, Hash, Eq, PartialEq)]
112#[non_exhaustive]
113pub enum ExperimentalTarget {
114 RecordReferences,
115 GraphQL,
116 BearerAccess,
117 DefineApi,
118}
119
120impl fmt::Display for ExperimentalTarget {
121 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
122 match self {
123 Self::RecordReferences => write!(f, "record_references"),
124 Self::GraphQL => write!(f, "graphql"),
125 Self::BearerAccess => write!(f, "bearer_access"),
126 Self::DefineApi => write!(f, "define_api"),
127 }
128 }
129}
130
131impl Target for ExperimentalTarget {
132 fn matches(&self, elem: &ExperimentalTarget) -> bool {
133 self == elem
134 }
135}
136
137impl Target<str> for ExperimentalTarget {
138 fn matches(&self, elem: &str) -> bool {
139 match self {
140 Self::RecordReferences => elem.eq_ignore_ascii_case("record_references"),
141 Self::GraphQL => elem.eq_ignore_ascii_case("graphql"),
142 Self::BearerAccess => elem.eq_ignore_ascii_case("bearer_access"),
143 Self::DefineApi => elem.eq_ignore_ascii_case("define_api"),
144 }
145 }
146}
147
/// Error returned when a string does not name a known experimental feature.
#[derive(Debug, Clone)]
pub enum ParseExperimentalTargetError {
    InvalidName,
}

impl std::error::Error for ParseExperimentalTargetError {}

impl fmt::Display for ParseExperimentalTargetError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::InvalidName => f.write_str("invalid experimental target name"),
        }
    }
}
163
164impl std::str::FromStr for ExperimentalTarget {
165 type Err = ParseExperimentalTargetError;
166
167 fn from_str(s: &str) -> Result<Self, Self::Err> {
168 match_insensitive!(s.trim(), {
169 "record_references" => Ok(ExperimentalTarget::RecordReferences),
170 "graphql" => Ok(ExperimentalTarget::GraphQL),
171 "bearer_access" => Ok(ExperimentalTarget::BearerAccess),
172 "define_api" => Ok(ExperimentalTarget::DefineApi),
173 _ => Err(ParseExperimentalTargetError::InvalidName),
174 })
175 }
176}
177
178#[derive(Debug, Clone, Hash, Eq, PartialEq)]
179#[non_exhaustive]
180pub enum NetTarget {
181 Host(url::Host<String>, Option<u16>),
182 IPNet(ipnet::IpNet),
183}
184
185impl fmt::Display for NetTarget {
187 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
188 match self {
189 Self::Host(host, Some(port)) => write!(f, "{}:{}", host, port),
190 Self::Host(host, None) => write!(f, "{}", host),
191 Self::IPNet(ipnet) => write!(f, "{}", ipnet),
192 }
193 }
194}
195
196impl Target for NetTarget {
197 fn matches(&self, elem: &Self) -> bool {
198 match self {
199 Self::Host(host, Some(port)) => match elem {
201 Self::Host(_host, Some(_port)) => host == _host && port == _port,
202 _ => false,
203 },
204 Self::Host(host, None) => match elem {
206 Self::Host(_host, _) => host == _host,
207 _ => false,
208 },
209 Self::IPNet(ipnet) => match elem {
211 Self::IPNet(_ipnet) => ipnet.contains(_ipnet),
212 Self::Host(host, _) => match host {
213 url::Host::Ipv4(ip) => ipnet.contains(&IpAddr::from(ip.to_owned())),
214 url::Host::Ipv6(ip) => ipnet.contains(&IpAddr::from(ip.to_owned())),
215 _ => false,
216 },
217 },
218 }
219 }
220}
221
/// Error returned when a string is not a valid network target.
#[derive(Debug)]
pub struct ParseNetTargetError;

impl std::error::Error for ParseNetTargetError {}

impl fmt::Display for ParseNetTargetError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(
            "The provided network target is not a valid host name, IP address or CIDR block",
        )
    }
}
231
232impl std::str::FromStr for NetTarget {
233 type Err = ParseNetTargetError;
234
235 fn from_str(s: &str) -> Result<Self, Self::Err> {
236 if let Ok(ipnet) = s.parse::<IpNet>() {
238 return Ok(NetTarget::IPNet(ipnet));
239 }
240
241 if let Ok(ipnet) = s.parse::<IpAddr>() {
243 return Ok(NetTarget::IPNet(IpNet::from(ipnet)));
244 }
245
246 if let Ok(url) = Url::parse(format!("http://{s}").as_str()) {
248 if let Some(host) = url.host() {
249 if let Some(Ok(port)) = s.split(':').next_back().map(|p| p.parse::<u16>()) {
251 return Ok(NetTarget::Host(host.to_owned(), Some(port)));
252 } else {
253 return Ok(NetTarget::Host(host.to_owned(), None));
254 }
255 }
256 }
257
258 Err(ParseNetTargetError)
259 }
260}
261
262#[derive(Debug, Clone, Hash, Eq, PartialEq)]
263pub struct MethodTarget {
264 pub method: Method,
265}
266
267impl fmt::Display for MethodTarget {
269 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
270 write!(f, "{}", self.method.to_str())
271 }
272}
273
274impl Target for MethodTarget {
275 fn matches(&self, elem: &Self) -> bool {
276 self.method == elem.method
277 }
278}
279
/// Error returned when a string does not name a known RPC method.
#[derive(Debug)]
pub struct ParseMethodTargetError;

impl std::error::Error for ParseMethodTargetError {}

impl fmt::Display for ParseMethodTargetError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("The provided method target is not a valid RPC method")
    }
}
289
290impl std::str::FromStr for MethodTarget {
291 type Err = ParseMethodTargetError;
292
293 fn from_str(s: &str) -> Result<Self, Self::Err> {
294 match Method::parse_case_insensitive(s) {
295 Method::Unknown => Err(ParseMethodTargetError),
296 method => Ok(MethodTarget {
297 method,
298 }),
299 }
300 }
301}
302
303#[derive(Debug, Clone, Hash, Eq, PartialEq)]
304#[non_exhaustive]
305pub enum RouteTarget {
306 Health,
307 Export,
308 Import,
309 Rpc,
310 Version,
311 Sync,
312 Sql,
313 Signin,
314 Signup,
315 Key,
316 Ml,
317 GraphQL,
318 Api,
319}
320
321impl fmt::Display for RouteTarget {
323 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
324 match self {
325 RouteTarget::Health => write!(f, "health"),
326 RouteTarget::Export => write!(f, "export"),
327 RouteTarget::Import => write!(f, "import"),
328 RouteTarget::Rpc => write!(f, "rpc"),
329 RouteTarget::Version => write!(f, "version"),
330 RouteTarget::Sync => write!(f, "sync"),
331 RouteTarget::Sql => write!(f, "sql"),
332 RouteTarget::Signin => write!(f, "signin"),
333 RouteTarget::Signup => write!(f, "signup"),
334 RouteTarget::Key => write!(f, "key"),
335 RouteTarget::Ml => write!(f, "ml"),
336 RouteTarget::GraphQL => write!(f, "graphql"),
337 RouteTarget::Api => write!(f, "api"),
338 }
339 }
340}
341
342impl Target for RouteTarget {
343 fn matches(&self, elem: &Self) -> bool {
344 *self == *elem
345 }
346}
347
/// Error returned when a string does not name a known HTTP route.
#[derive(Debug)]
pub struct ParseRouteTargetError;

impl std::error::Error for ParseRouteTargetError {}

impl fmt::Display for ParseRouteTargetError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("The provided route target is not a valid HTTP route")
    }
}
357
358impl std::str::FromStr for RouteTarget {
359 type Err = ParseRouteTargetError;
360
361 fn from_str(s: &str) -> Result<Self, Self::Err> {
362 match s {
363 "health" => Ok(RouteTarget::Health),
364 "export" => Ok(RouteTarget::Export),
365 "import" => Ok(RouteTarget::Import),
366 "rpc" => Ok(RouteTarget::Rpc),
367 "version" => Ok(RouteTarget::Version),
368 "sync" => Ok(RouteTarget::Sync),
369 "sql" => Ok(RouteTarget::Sql),
370 "signin" => Ok(RouteTarget::Signin),
371 "signup" => Ok(RouteTarget::Signup),
372 "key" => Ok(RouteTarget::Key),
373 "ml" => Ok(RouteTarget::Ml),
374 "graphql" => Ok(RouteTarget::GraphQL),
375 "api" => Ok(RouteTarget::Api),
376 _ => Err(ParseRouteTargetError),
377 }
378 }
379}
380
/// The class of caller an arbitrary-query capability applies to.
#[derive(Debug, Clone, Hash, Eq, PartialEq)]
#[non_exhaustive]
pub enum ArbitraryQueryTarget {
    Guest,
    Record,
    System,
}

impl fmt::Display for ArbitraryQueryTarget {
    /// Formats the target as its lowercase name.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = match self {
            Self::Guest => "guest",
            Self::Record => "record",
            Self::System => "system",
        };
        f.write_str(name)
    }
}
398
399impl<'a> From<&'a Level> for ArbitraryQueryTarget {
400 fn from(level: &'a Level) -> Self {
401 match level {
402 Level::No => ArbitraryQueryTarget::Guest,
403 Level::Root => ArbitraryQueryTarget::System,
404 Level::Namespace(_) => ArbitraryQueryTarget::System,
405 Level::Database(_, _) => ArbitraryQueryTarget::System,
406 Level::Record(_, _, _) => ArbitraryQueryTarget::Record,
407 }
408 }
409}
410
411impl<'a> From<&'a Auth> for ArbitraryQueryTarget {
412 fn from(auth: &'a Auth) -> Self {
413 auth.level().into()
414 }
415}
416
417impl Target for ArbitraryQueryTarget {
418 fn matches(&self, elem: &ArbitraryQueryTarget) -> bool {
419 self == elem
420 }
421}
422
423impl Target<str> for ArbitraryQueryTarget {
424 fn matches(&self, elem: &str) -> bool {
425 match self {
426 Self::Guest => elem.eq_ignore_ascii_case("guest"),
427 Self::Record => elem.eq_ignore_ascii_case("record"),
428 Self::System => elem.eq_ignore_ascii_case("system"),
429 }
430 }
431}
432
/// Error returned when a string does not name a known arbitrary-query target.
#[derive(Debug, Clone)]
pub enum ParseArbitraryQueryTargetError {
    InvalidName,
}

impl std::error::Error for ParseArbitraryQueryTargetError {}

impl fmt::Display for ParseArbitraryQueryTargetError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::InvalidName => f.write_str("invalid query target name"),
        }
    }
}
448
449impl std::str::FromStr for ArbitraryQueryTarget {
450 type Err = ParseArbitraryQueryTargetError;
451
452 fn from_str(s: &str) -> Result<Self, Self::Err> {
453 match_insensitive!(s.trim(), {
454 "guest" => Ok(ArbitraryQueryTarget::Guest),
455 "record" => Ok(ArbitraryQueryTarget::Record),
456 "system" => Ok(ArbitraryQueryTarget::System),
457 _ => Err(ParseArbitraryQueryTargetError::InvalidName),
458 })
459 }
460}
461
462#[derive(Debug, Clone, Eq, PartialEq)]
463#[non_exhaustive]
464pub enum Targets<T: Hash + Eq + PartialEq> {
465 None,
466 Some(HashSet<T>),
467 All,
468}
469
470impl<T: Target + Hash + Eq + PartialEq> From<T> for Targets<T> {
471 fn from(t: T) -> Self {
472 let mut set = HashSet::new();
473 set.insert(t);
474 Self::Some(set)
475 }
476}
477
478impl<T: Hash + Eq + PartialEq + fmt::Debug + fmt::Display> Targets<T> {
479 pub(crate) fn matches<S>(&self, elem: &S) -> bool
480 where
481 S: ?Sized,
482 T: Target<S>,
483 {
484 match self {
485 Self::None => false,
486 Self::All => true,
487 Self::Some(targets) => targets.iter().any(|t| t.matches(elem)),
488 }
489 }
490}
491
492impl<T: Target + Hash + Eq + PartialEq + fmt::Display> fmt::Display for Targets<T> {
493 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
494 match self {
495 Self::None => write!(f, "none"),
496 Self::All => write!(f, "all"),
497 Self::Some(targets) => {
498 let targets =
499 targets.iter().map(|t| t.to_string()).collect::<Vec<String>>().join(", ");
500 write!(f, "{}", targets)
501 }
502 }
503 }
504}
505
/// Capability configuration: feature flags plus allow/deny target lists.
///
/// For each kind of target, the `allows_*` methods permit an element only when
/// it matches the allow list and does not match the deny list.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct Capabilities {
    /// Whether scripting is allowed (read by `allows_scripting`).
    scripting: bool,
    /// Whether guest access is allowed (read by `allows_guest_access`).
    guest_access: bool,
    /// Whether live query notifications are allowed (read by
    /// `allows_live_query_notifications`).
    live_query_notifications: bool,

    /// Function names that are allowed (checked by `allows_function_name`).
    allow_funcs: Targets<FuncTarget>,
    /// Function names that are denied, overriding `allow_funcs`.
    deny_funcs: Targets<FuncTarget>,
    /// Network targets that are allowed (checked by `allows_network_target`).
    allow_net: Targets<NetTarget>,
    /// Network targets that are denied, overriding `allow_net`.
    deny_net: Targets<NetTarget>,
    /// RPC methods that are allowed (checked by `allows_rpc_method`).
    allow_rpc: Targets<MethodTarget>,
    /// RPC methods that are denied, overriding `allow_rpc`.
    deny_rpc: Targets<MethodTarget>,
    /// HTTP routes that are allowed (checked by `allows_http_route`).
    allow_http: Targets<RouteTarget>,
    /// HTTP routes that are denied, overriding `allow_http`.
    deny_http: Targets<RouteTarget>,
    /// Experimental features that are allowed (checked by
    /// `allows_experimental` / `allows_experimental_name`).
    allow_experimental: Targets<ExperimentalTarget>,
    /// Experimental features that are denied, overriding `allow_experimental`.
    deny_experimental: Targets<ExperimentalTarget>,
    /// Arbitrary-query targets that are allowed (checked by `allows_query` /
    /// `allows_query_name`).
    allow_arbitrary_query: Targets<ArbitraryQueryTarget>,
    /// Arbitrary-query targets that are denied, overriding
    /// `allow_arbitrary_query`.
    deny_arbitrary_query: Targets<ArbitraryQueryTarget>,
}
526
527impl fmt::Display for Capabilities {
528 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
529 write!(
530 f,
531 "scripting={}, guest_access={}, live_query_notifications={}, allow_funcs={}, deny_funcs={}, allow_net={}, deny_net={}, allow_rpc={}, deny_rpc={}, allow_http={}, deny_http={}, allow_experimental={}, deny_experimental={}, allow_arbitrary_query={}, deny_arbitrary_query={}",
532 self.scripting, self.guest_access, self.live_query_notifications, self.allow_funcs, self.deny_funcs, self.allow_net, self.deny_net, self.allow_rpc, self.deny_rpc, self.allow_http, self.deny_http, self.allow_experimental, self.deny_experimental, self.allow_arbitrary_query, self.deny_arbitrary_query,
533 )
534 }
535}
536
537impl Default for Capabilities {
538 fn default() -> Self {
539 Self {
540 scripting: false,
541 guest_access: false,
542 live_query_notifications: true,
543
544 allow_funcs: Targets::All,
545 deny_funcs: Targets::None,
546 allow_net: Targets::None,
547 deny_net: Targets::None,
548 allow_rpc: Targets::All,
549 deny_rpc: Targets::None,
550 allow_http: Targets::All,
551 deny_http: Targets::None,
552 allow_experimental: Targets::None,
553 deny_experimental: Targets::None,
554 allow_arbitrary_query: Targets::All,
555 deny_arbitrary_query: Targets::None,
556 }
557 }
558}
559
560impl Capabilities {
561 pub fn all() -> Self {
562 Self {
563 scripting: true,
564 guest_access: true,
565 live_query_notifications: true,
566
567 allow_funcs: Targets::All,
568 deny_funcs: Targets::None,
569 allow_net: Targets::All,
570 deny_net: Targets::None,
571 allow_rpc: Targets::All,
572 deny_rpc: Targets::None,
573 allow_http: Targets::All,
574 deny_http: Targets::None,
575 allow_experimental: Targets::None,
576 deny_experimental: Targets::None,
577 allow_arbitrary_query: Targets::All,
578 deny_arbitrary_query: Targets::None,
579 }
580 }
581
582 pub fn none() -> Self {
583 Self {
584 scripting: false,
585 guest_access: false,
586 live_query_notifications: false,
587
588 allow_funcs: Targets::None,
589 deny_funcs: Targets::None,
590 allow_net: Targets::None,
591 deny_net: Targets::None,
592 allow_rpc: Targets::None,
593 deny_rpc: Targets::None,
594 allow_http: Targets::None,
595 deny_http: Targets::None,
596 allow_experimental: Targets::None,
597 deny_experimental: Targets::None,
598 allow_arbitrary_query: Targets::None,
599 deny_arbitrary_query: Targets::None,
600 }
601 }
602
603 pub fn with_scripting(mut self, scripting: bool) -> Self {
604 self.scripting = scripting;
605 self
606 }
607
608 pub fn with_guest_access(mut self, guest_access: bool) -> Self {
609 self.guest_access = guest_access;
610 self
611 }
612
613 pub fn with_live_query_notifications(mut self, live_query_notifications: bool) -> Self {
614 self.live_query_notifications = live_query_notifications;
615 self
616 }
617
618 pub fn with_functions(mut self, allow_funcs: Targets<FuncTarget>) -> Self {
619 self.allow_funcs = allow_funcs;
620 self
621 }
622
623 pub fn allowed_functions_mut(&mut self) -> &mut Targets<FuncTarget> {
624 &mut self.allow_funcs
625 }
626
627 pub fn without_functions(mut self, deny_funcs: Targets<FuncTarget>) -> Self {
628 self.deny_funcs = deny_funcs;
629 self
630 }
631
632 pub fn denied_functions_mut(&mut self) -> &mut Targets<FuncTarget> {
633 &mut self.deny_funcs
634 }
635
636 pub fn with_experimental(mut self, allow_experimental: Targets<ExperimentalTarget>) -> Self {
637 self.allow_experimental = allow_experimental;
638 self
639 }
640
641 pub fn allowed_experimental_features_mut(&mut self) -> &mut Targets<ExperimentalTarget> {
642 &mut self.allow_experimental
643 }
644
645 pub fn without_experimental(mut self, deny_experimental: Targets<ExperimentalTarget>) -> Self {
646 self.deny_experimental = deny_experimental;
647 self
648 }
649
650 pub fn denied_experimental_features_mut(&mut self) -> &mut Targets<ExperimentalTarget> {
651 &mut self.deny_experimental
652 }
653
654 pub fn with_arbitrary_query(
655 mut self,
656 allow_arbitrary_query: Targets<ArbitraryQueryTarget>,
657 ) -> Self {
658 self.allow_arbitrary_query = allow_arbitrary_query;
659 self
660 }
661
662 pub fn without_arbitrary_query(
663 mut self,
664 deny_arbitrary_query: Targets<ArbitraryQueryTarget>,
665 ) -> Self {
666 self.deny_arbitrary_query = deny_arbitrary_query;
667 self
668 }
669
670 pub fn with_network_targets(mut self, allow_net: Targets<NetTarget>) -> Self {
671 self.allow_net = allow_net;
672 self
673 }
674
675 pub fn allowed_network_targets_mut(&mut self) -> &mut Targets<NetTarget> {
676 &mut self.allow_net
677 }
678
679 pub fn without_network_targets(mut self, deny_net: Targets<NetTarget>) -> Self {
680 self.deny_net = deny_net;
681 self
682 }
683
684 pub fn denied_network_targets_mut(&mut self) -> &mut Targets<NetTarget> {
685 &mut self.deny_net
686 }
687
688 pub fn with_rpc_methods(mut self, allow_rpc: Targets<MethodTarget>) -> Self {
689 self.allow_rpc = allow_rpc;
690 self
691 }
692
693 pub fn without_rpc_methods(mut self, deny_rpc: Targets<MethodTarget>) -> Self {
694 self.deny_rpc = deny_rpc;
695 self
696 }
697
698 pub fn with_http_routes(mut self, allow_http: Targets<RouteTarget>) -> Self {
699 self.allow_http = allow_http;
700 self
701 }
702
703 pub fn without_http_routes(mut self, deny_http: Targets<RouteTarget>) -> Self {
704 self.deny_http = deny_http;
705 self
706 }
707
708 pub fn allows_scripting(&self) -> bool {
709 self.scripting
710 }
711
712 pub fn allows_guest_access(&self) -> bool {
713 self.guest_access
714 }
715
716 pub fn allows_live_query_notifications(&self) -> bool {
717 self.live_query_notifications
718 }
719
720 pub fn allows_function_name(&self, target: &str) -> bool {
721 self.allow_funcs.matches(target) && !self.deny_funcs.matches(target)
722 }
723
724 pub fn allows_experimental(&self, target: &ExperimentalTarget) -> bool {
725 self.allow_experimental.matches(target) && !self.deny_experimental.matches(target)
726 }
727
728 pub fn allows_experimental_name(&self, target: &str) -> bool {
729 self.allow_experimental.matches(target) && !self.deny_experimental.matches(target)
730 }
731
732 pub fn allows_query(&self, target: &ArbitraryQueryTarget) -> bool {
733 self.allow_arbitrary_query.matches(target) && !self.deny_arbitrary_query.matches(target)
734 }
735
736 pub fn allows_query_name(&self, target: &str) -> bool {
737 self.allow_arbitrary_query.matches(target) && !self.deny_arbitrary_query.matches(target)
738 }
739
740 pub fn allows_network_target(&self, target: &NetTarget) -> bool {
741 self.allow_net.matches(target) && !self.deny_net.matches(target)
742 }
743
744 pub fn allows_rpc_method(&self, target: &MethodTarget) -> bool {
745 self.allow_rpc.matches(target) && !self.deny_rpc.matches(target)
746 }
747
748 pub fn allows_http_route(&self, target: &RouteTarget) -> bool {
749 self.allow_http.matches(target) && !self.deny_http.matches(target)
750 }
751}
752
// Unit tests for target parsing and matching, `Targets` semantics and the
// `Capabilities` allow/deny evaluation.
#[cfg(test)]
mod tests {
    use std::str::FromStr;
    // Replaces the built-in `#[test]` attribute with `test_log`'s version
    // (presumably to initialise logging per test — TODO confirm).
    use test_log::test;

    use super::*;

    // Malformed function targets must be rejected by `FuncTarget::from_str`.
    #[test]
    fn test_invalid_func_target() {
        FuncTarget::from_str("te::*st").unwrap_err();
        FuncTarget::from_str("\0::st").unwrap_err();
        FuncTarget::from_str("").unwrap_err();
        FuncTarget::from_str("❤️").unwrap_err();
    }

    // Parsed function targets matched against raw function names.
    #[test]
    fn test_func_target() {
        // A bare family matches its own name but not a different one.
        assert!(FuncTarget::from_str("test").unwrap().matches("test"));
        assert!(!FuncTarget::from_str("test").unwrap().matches("test2"));

        // `test::` is a named target with an empty name; a name without `::`
        // never matches a named target.
        assert!(!FuncTarget::from_str("test::").unwrap().matches("test"));

        // `family::*` matches any function in that family only.
        assert!(FuncTarget::from_str("test::*").unwrap().matches("test::name"));
        assert!(!FuncTarget::from_str("test::*").unwrap().matches("test2::name"));

        // A fully-qualified target matches only the exact name.
        assert!(FuncTarget::from_str("test::name").unwrap().matches("test::name"));
        assert!(!FuncTarget::from_str("test::name").unwrap().matches("test::name2"));
    }

    #[test]
    fn test_net_target() {
        // IPv4 networks contain sub-networks and addresses inside their range.
        assert!(NetTarget::from_str("10.0.0.0/8")
            .unwrap()
            .matches(&NetTarget::from_str("10.0.1.0/24").unwrap()));
        assert!(NetTarget::from_str("10.0.0.0/8")
            .unwrap()
            .matches(&NetTarget::from_str("10.0.1.2").unwrap()));
        assert!(!NetTarget::from_str("10.0.0.0/8")
            .unwrap()
            .matches(&NetTarget::from_str("20.0.1.0/24").unwrap()));
        assert!(!NetTarget::from_str("10.0.0.0/8")
            .unwrap()
            .matches(&NetTarget::from_str("20.0.1.0").unwrap()));

        // IPv6 addresses and networks.
        assert!(NetTarget::from_str("2001:db8::1")
            .unwrap()
            .matches(&NetTarget::from_str("2001:db8::1").unwrap()));
        assert!(NetTarget::from_str("2001:db8::/32")
            .unwrap()
            .matches(&NetTarget::from_str("2001:db8::1").unwrap()));
        assert!(NetTarget::from_str("2001:db8::/32")
            .unwrap()
            .matches(&NetTarget::from_str("2001:db8:abcd:12::/64").unwrap()));
        assert!(!NetTarget::from_str("2001:db8::/32")
            .unwrap()
            .matches(&NetTarget::from_str("2001:db9::1").unwrap()));
        assert!(!NetTarget::from_str("2001:db8::/32")
            .unwrap()
            .matches(&NetTarget::from_str("2001:db9:abcd:12::1/64").unwrap()));

        // Domain hosts: a port-less host matches any port; a host with a port
        // matches only that port; subdomains are distinct hosts.
        assert!(NetTarget::from_str("example.com")
            .unwrap()
            .matches(&NetTarget::from_str("example.com").unwrap()));
        assert!(NetTarget::from_str("example.com")
            .unwrap()
            .matches(&NetTarget::from_str("example.com:80").unwrap()));
        assert!(!NetTarget::from_str("example.com")
            .unwrap()
            .matches(&NetTarget::from_str("www.example.com").unwrap()));
        assert!(!NetTarget::from_str("example.com")
            .unwrap()
            .matches(&NetTarget::from_str("www.example.com:80").unwrap()));
        assert!(NetTarget::from_str("example.com:80")
            .unwrap()
            .matches(&NetTarget::from_str("example.com:80").unwrap()));
        assert!(!NetTarget::from_str("example.com:80")
            .unwrap()
            .matches(&NetTarget::from_str("example.com:443").unwrap()));
        assert!(!NetTarget::from_str("example.com:80")
            .unwrap()
            .matches(&NetTarget::from_str("example.com").unwrap()));

        // IPv4 hosts with and without ports.
        assert!(
            NetTarget::from_str("127.0.0.1")
                .unwrap()
                .matches(&NetTarget::from_str("127.0.0.1").unwrap()),
            "Host IPv4 without port matches itself"
        );
        assert!(
            NetTarget::from_str("127.0.0.1")
                .unwrap()
                .matches(&NetTarget::from_str("127.0.0.1:80").unwrap()),
            "Host IPv4 without port matches Host IPv4 with port"
        );
        assert!(
            NetTarget::from_str("10.0.0.0/8")
                .unwrap()
                .matches(&NetTarget::from_str("10.0.0.1:80").unwrap()),
            "IPv4 network matches Host IPv4 with port"
        );
        assert!(
            NetTarget::from_str("127.0.0.1:80")
                .unwrap()
                .matches(&NetTarget::from_str("127.0.0.1:80").unwrap()),
            "Host IPv4 with port matches itself"
        );
        assert!(
            !NetTarget::from_str("127.0.0.1:80")
                .unwrap()
                .matches(&NetTarget::from_str("127.0.0.1").unwrap()),
            "Host IPv4 with port does not match Host IPv4 without port"
        );
        assert!(
            !NetTarget::from_str("127.0.0.1:80")
                .unwrap()
                .matches(&NetTarget::from_str("127.0.0.1:443").unwrap()),
            "Host IPv4 with port does not match Host IPv4 with different port"
        );

        // Bracketed IPv6 hosts with and without ports.
        assert!(
            NetTarget::from_str("[2001:db8::1]")
                .unwrap()
                .matches(&NetTarget::from_str("[2001:db8::1]").unwrap()),
            "Host IPv6 without port matches itself"
        );
        assert!(
            NetTarget::from_str("[2001:db8::1]")
                .unwrap()
                .matches(&NetTarget::from_str("[2001:db8::1]:80").unwrap()),
            "Host IPv6 without port matches Host IPv6 with port"
        );
        assert!(
            NetTarget::from_str("2001:db8::1")
                .unwrap()
                .matches(&NetTarget::from_str("[2001:db8::1]:80").unwrap()),
            "IPv6 addr matches Host IPv6 with port"
        );
        assert!(
            NetTarget::from_str("2001:db8::/64")
                .unwrap()
                .matches(&NetTarget::from_str("[2001:db8::1]:80").unwrap()),
            "IPv6 network matches Host IPv6 with port"
        );
        assert!(
            NetTarget::from_str("[2001:db8::1]:80")
                .unwrap()
                .matches(&NetTarget::from_str("[2001:db8::1]:80").unwrap()),
            "Host IPv6 with port matches itself"
        );
        assert!(
            !NetTarget::from_str("[2001:db8::1]:80")
                .unwrap()
                .matches(&NetTarget::from_str("[2001:db8::1]").unwrap()),
            "Host IPv6 with port does not match Host IPv6 without port"
        );
        assert!(
            !NetTarget::from_str("[2001:db8::1]:80")
                .unwrap()
                .matches(&NetTarget::from_str("[2001:db8::1]:443").unwrap()),
            "Host IPv6 with port does not match Host IPv6 with different port"
        );

        // Inputs that must fail to parse.
        assert!(NetTarget::from_str("exam^ple.com").is_err());
        assert!(NetTarget::from_str("example.com:80:80").is_err());
        assert!(NetTarget::from_str("11111.3.4.5").is_err());
        assert!(NetTarget::from_str("2001:db8::1/129").is_err());
        assert!(NetTarget::from_str("[2001:db8::1").is_err());
    }

    // Method names parse case-insensitively and compare by method.
    #[test]
    fn test_method_target() {
        assert!(MethodTarget::from_str("query")
            .unwrap()
            .matches(&MethodTarget::from_str("query").unwrap()));
        assert!(MethodTarget::from_str("query")
            .unwrap()
            .matches(&MethodTarget::from_str("Query").unwrap()));
        assert!(MethodTarget::from_str("query")
            .unwrap()
            .matches(&MethodTarget::from_str("QUERY").unwrap()));
        assert!(!MethodTarget::from_str("query")
            .unwrap()
            .matches(&MethodTarget::from_str("ping").unwrap()));
    }

    // `All` matches everything, `None` nothing, and `Some` delegates to the
    // contained targets.
    #[test]
    fn test_targets() {
        assert!(Targets::<NetTarget>::All.matches(&NetTarget::from_str("example.com").unwrap()));
        assert!(Targets::<FuncTarget>::All.matches("http::get"));
        assert!(!Targets::<NetTarget>::None.matches(&NetTarget::from_str("example.com").unwrap()));
        assert!(!Targets::<FuncTarget>::None.matches("http::get"));
        assert!(Targets::<NetTarget>::Some([NetTarget::from_str("example.com").unwrap()].into())
            .matches(&NetTarget::from_str("example.com").unwrap()));
        assert!(!Targets::<NetTarget>::Some([NetTarget::from_str("example.com").unwrap()].into())
            .matches(&NetTarget::from_str("www.example.com").unwrap()));
        assert!(Targets::<FuncTarget>::Some([FuncTarget::from_str("http::get").unwrap()].into())
            .matches("http::get"));
        assert!(!Targets::<FuncTarget>::Some([FuncTarget::from_str("http::get").unwrap()].into())
            .matches("http::post"));
    }

    #[test]
    fn test_capabilities() {
        // Scripting enabled.
        {
            let caps = Capabilities::default().with_scripting(true);
            assert!(caps.allows_scripting());
        }

        // Scripting disabled.
        {
            let caps = Capabilities::default().with_scripting(false);
            assert!(!caps.allows_scripting());
        }

        // Guest access enabled.
        {
            let caps = Capabilities::default().with_guest_access(true);
            assert!(caps.allows_guest_access());
        }

        // Guest access disabled.
        {
            let caps = Capabilities::default().with_guest_access(false);
            assert!(!caps.allows_guest_access());
        }

        // Live query notifications enabled.
        {
            let cap = Capabilities::default().with_live_query_notifications(true);
            assert!(cap.allows_live_query_notifications());
        }

        // Live query notifications disabled.
        {
            let cap = Capabilities::default().with_live_query_notifications(false);
            assert!(!cap.allows_live_query_notifications());
        }

        // All network targets allowed, none denied.
        {
            let caps = Capabilities::default()
                .with_network_targets(Targets::<NetTarget>::All)
                .without_network_targets(Targets::<NetTarget>::None);
            assert!(caps.allows_network_target(&NetTarget::from_str("example.com").unwrap()));
            assert!(caps.allows_network_target(&NetTarget::from_str("example.com:80").unwrap()));
        }

        // All network targets allowed but also all denied: deny wins.
        {
            let caps = Capabilities::default()
                .with_network_targets(Targets::<NetTarget>::All)
                .without_network_targets(Targets::<NetTarget>::All);
            assert!(!caps.allows_network_target(&NetTarget::from_str("example.com").unwrap()));
            assert!(!caps.allows_network_target(&NetTarget::from_str("example.com:80").unwrap()));
        }

        // A host is allowed but one specific port on it is denied.
        {
            let caps = Capabilities::default()
                .with_network_targets(Targets::<NetTarget>::Some(
                    [NetTarget::from_str("example.com").unwrap()].into(),
                ))
                .without_network_targets(Targets::<NetTarget>::Some(
                    [NetTarget::from_str("example.com:80").unwrap()].into(),
                ));
            assert!(caps.allows_network_target(&NetTarget::from_str("example.com").unwrap()));
            assert!(caps.allows_network_target(&NetTarget::from_str("example.com:443").unwrap()));
            assert!(!caps.allows_network_target(&NetTarget::from_str("example.com:80").unwrap()));
        }

        // All functions allowed, none denied.
        {
            let caps = Capabilities::default()
                .with_functions(Targets::<FuncTarget>::All)
                .without_functions(Targets::<FuncTarget>::None);
            assert!(caps.allows_function_name("http::get"));
            assert!(caps.allows_function_name("http::post"));
        }

        // All functions allowed and all denied: deny wins.
        {
            let caps = Capabilities::default()
                .with_functions(Targets::<FuncTarget>::All)
                .without_functions(Targets::<FuncTarget>::All);
            assert!(!caps.allows_function_name("http::get"));
            assert!(!caps.allows_function_name("http::post"));
        }

        // A function family is allowed but one function in it is denied.
        {
            let caps = Capabilities::default()
                .with_functions(Targets::<FuncTarget>::Some(
                    [FuncTarget::from_str("http::*").unwrap()].into(),
                ))
                .without_functions(Targets::<FuncTarget>::Some(
                    [FuncTarget::from_str("http::post").unwrap()].into(),
                ));
            assert!(caps.allows_function_name("http::get"));
            assert!(caps.allows_function_name("http::put"));
            assert!(!caps.allows_function_name("http::post"));
        }

        // All RPC methods allowed, none denied.
        {
            let caps = Capabilities::default()
                .with_rpc_methods(Targets::<MethodTarget>::All)
                .without_rpc_methods(Targets::<MethodTarget>::None);
            assert!(caps.allows_rpc_method(&MethodTarget::from_str("ping").unwrap()));
            assert!(caps.allows_rpc_method(&MethodTarget::from_str("select").unwrap()));
            assert!(caps.allows_rpc_method(&MethodTarget::from_str("query").unwrap()));
        }

        // All RPC methods allowed and all denied: deny wins.
        {
            let caps = Capabilities::default()
                .with_rpc_methods(Targets::<MethodTarget>::All)
                .without_rpc_methods(Targets::<MethodTarget>::All);
            assert!(!caps.allows_rpc_method(&MethodTarget::from_str("ping").unwrap()));
            assert!(!caps.allows_rpc_method(&MethodTarget::from_str("select").unwrap()));
            assert!(!caps.allows_rpc_method(&MethodTarget::from_str("query").unwrap()));
        }

        // A subset of RPC methods allowed, with some of them denied.
        {
            let caps = Capabilities::default()
                .with_rpc_methods(Targets::<MethodTarget>::Some(
                    [
                        MethodTarget::from_str("select").unwrap(),
                        MethodTarget::from_str("create").unwrap(),
                        MethodTarget::from_str("insert").unwrap(),
                        MethodTarget::from_str("update").unwrap(),
                        MethodTarget::from_str("query").unwrap(),
                        MethodTarget::from_str("run").unwrap(),
                    ]
                    .into(),
                ))
                .without_rpc_methods(Targets::<MethodTarget>::Some(
                    [
                        MethodTarget::from_str("query").unwrap(),
                        MethodTarget::from_str("run").unwrap(),
                    ]
                    .into(),
                ));

            assert!(caps.allows_rpc_method(&MethodTarget::from_str("select").unwrap()));
            assert!(caps.allows_rpc_method(&MethodTarget::from_str("create").unwrap()));
            assert!(caps.allows_rpc_method(&MethodTarget::from_str("insert").unwrap()));
            assert!(caps.allows_rpc_method(&MethodTarget::from_str("update").unwrap()));
            assert!(!caps.allows_rpc_method(&MethodTarget::from_str("query").unwrap()));
            assert!(!caps.allows_rpc_method(&MethodTarget::from_str("run").unwrap()));
        }

        // All HTTP routes allowed, none denied.
        {
            let caps = Capabilities::default()
                .with_http_routes(Targets::<RouteTarget>::All)
                .without_http_routes(Targets::<RouteTarget>::None);
            assert!(caps.allows_http_route(&RouteTarget::from_str("version").unwrap()));
            assert!(caps.allows_http_route(&RouteTarget::from_str("export").unwrap()));
            assert!(caps.allows_http_route(&RouteTarget::from_str("sql").unwrap()));
        }

        // All HTTP routes allowed and all denied: deny wins.
        {
            let caps = Capabilities::default()
                .with_http_routes(Targets::<RouteTarget>::All)
                .without_http_routes(Targets::<RouteTarget>::All);
            assert!(!caps.allows_http_route(&RouteTarget::from_str("version").unwrap()));
            assert!(!caps.allows_http_route(&RouteTarget::from_str("export").unwrap()));
            assert!(!caps.allows_http_route(&RouteTarget::from_str("sql").unwrap()));
        }

        // A subset of HTTP routes allowed, with some of them denied.
        {
            let caps = Capabilities::default()
                .with_http_routes(Targets::<RouteTarget>::Some(
                    [
                        RouteTarget::from_str("version").unwrap(),
                        RouteTarget::from_str("import").unwrap(),
                        RouteTarget::from_str("export").unwrap(),
                        RouteTarget::from_str("key").unwrap(),
                        RouteTarget::from_str("sql").unwrap(),
                        RouteTarget::from_str("rpc").unwrap(),
                    ]
                    .into(),
                ))
                .without_http_routes(Targets::<RouteTarget>::Some(
                    [RouteTarget::from_str("sql").unwrap(), RouteTarget::from_str("rpc").unwrap()]
                        .into(),
                ));

            assert!(caps.allows_http_route(&RouteTarget::from_str("version").unwrap()));
            assert!(caps.allows_http_route(&RouteTarget::from_str("import").unwrap()));
            assert!(caps.allows_http_route(&RouteTarget::from_str("export").unwrap()));
            assert!(caps.allows_http_route(&RouteTarget::from_str("key").unwrap()));
            assert!(!caps.allows_http_route(&RouteTarget::from_str("sql").unwrap()));
            assert!(!caps.allows_http_route(&RouteTarget::from_str("rpc").unwrap()));
        }
    }
}