// surrealdb_core/sql/statements/foreach.rs

1use crate::ctx::{Context, MutableContext};
2use crate::dbs::Options;
3use crate::doc::CursorDoc;
4use crate::err::Error;
5use crate::sql::{block::Entry, Block, Param, Value};
6
7use reblessive::tree::Stk;
8use revision::revisioned;
9use serde::{Deserialize, Serialize};
10use std::fmt::{self, Display};
11use std::ops::Deref;
12
/// A `FOR <param> IN <range> <block>` loop statement.
///
/// Field order determines the derived `PartialOrd` / `Hash` behaviour.
/// NOTE(review): it presumably also fixes the `revisioned` / serde encoding of
/// this statement — confirm before reordering or retyping fields.
#[revisioned(revision = 1)]
#[derive(Clone, Debug, Default, Eq, PartialEq, PartialOrd, Serialize, Deserialize, Hash)]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
#[non_exhaustive]
pub struct ForeachStatement {
	/// The loop variable, bound to each item in turn for the body.
	pub param: Param,
	/// The expression iterated over; it must evaluate to an array or a range.
	pub range: Value,
	/// The statements executed once per item.
	pub block: Block,
}
22
23enum ForeachIter {
24	Array(std::vec::IntoIter<Value>),
25	Range(std::iter::Map<std::ops::Range<i64>, fn(i64) -> Value>),
26}
27
28impl Iterator for ForeachIter {
29	type Item = Value;
30
31	fn next(&mut self) -> Option<Self::Item> {
32		match self {
33			ForeachIter::Array(iter) => iter.next(),
34			ForeachIter::Range(iter) => iter.next(),
35		}
36	}
37}
38
39impl ForeachStatement {
40	/// Check if we require a writeable transaction
41	pub(crate) fn writeable(&self) -> bool {
42		self.range.writeable() || self.block.writeable()
43	}
44	/// Process this type returning a computed simple Value
45	///
46	/// Was marked recursive
47	pub(crate) async fn compute(
48		&self,
49		stk: &mut Stk,
50		ctx: &Context,
51		opt: &Options,
52		doc: Option<&CursorDoc>,
53	) -> Result<Value, Error> {
54		// Check the loop data
55		let data = self.range.compute(stk, ctx, opt, doc).await?;
56		let iter = match data {
57			Value::Array(arr) => ForeachIter::Array(arr.into_iter()),
58			Value::Range(r) => {
59				let r: std::ops::Range<i64> = r.deref().to_owned().try_into()?;
60				ForeachIter::Range(r.map(Value::from))
61			}
62
63			v => {
64				return Err(Error::InvalidStatementTarget {
65					value: v.to_string(),
66				})
67			}
68		};
69
70		// Loop over the values
71		'foreach: for v in iter {
72			// Duplicate context
73			let ctx = MutableContext::new(ctx).freeze();
74			// Set the current parameter
75			let key = self.param.0.to_raw();
76			let val = stk.run(|stk| v.compute(stk, &ctx, opt, doc)).await?;
77			let mut ctx = MutableContext::unfreeze(ctx)?;
78			ctx.add_value(key, val.into());
79			let mut ctx = ctx.freeze();
80			// Loop over the code block statements
81			for v in self.block.iter() {
82				// Compute each block entry
83				let res = match v {
84					Entry::Set(v) => {
85						let val = stk.run(|stk| v.compute(stk, &ctx, opt, doc)).await?;
86						let mut c = MutableContext::unfreeze(ctx)?;
87						c.add_value(v.name.to_owned(), val.into());
88						ctx = c.freeze();
89						Ok(Value::None)
90					}
91					Entry::Value(v) => stk.run(|stk| v.compute(stk, &ctx, opt, doc)).await,
92					Entry::Break(v) => v.compute(&ctx, opt, doc).await,
93					Entry::Continue(v) => v.compute(&ctx, opt, doc).await,
94					Entry::Foreach(v) => stk.run(|stk| v.compute(stk, &ctx, opt, doc)).await,
95					Entry::Ifelse(v) => stk.run(|stk| v.compute(stk, &ctx, opt, doc)).await,
96					Entry::Select(v) => stk.run(|stk| v.compute(stk, &ctx, opt, doc)).await,
97					Entry::Create(v) => stk.run(|stk| v.compute(stk, &ctx, opt, doc)).await,
98					Entry::Upsert(v) => stk.run(|stk| v.compute(stk, &ctx, opt, doc)).await,
99					Entry::Update(v) => stk.run(|stk| v.compute(stk, &ctx, opt, doc)).await,
100					Entry::Delete(v) => stk.run(|stk| v.compute(stk, &ctx, opt, doc)).await,
101					Entry::Relate(v) => stk.run(|stk| v.compute(stk, &ctx, opt, doc)).await,
102					Entry::Insert(v) => stk.run(|stk| v.compute(stk, &ctx, opt, doc)).await,
103					Entry::Define(v) => v.compute(stk, &ctx, opt, doc).await,
104					Entry::Alter(v) => v.compute(stk, &ctx, opt, doc).await,
105					Entry::Rebuild(v) => v.compute(stk, &ctx, opt, doc).await,
106					Entry::Remove(v) => v.compute(&ctx, opt, doc).await,
107					Entry::Output(v) => {
108						return stk.run(|stk| v.compute(stk, &ctx, opt, doc)).await;
109					}
110					Entry::Throw(v) => {
111						return stk.run(|stk| v.compute(stk, &ctx, opt, doc)).await;
112					}
113				};
114				// Catch any special errors
115				match res {
116					Err(Error::Continue) => continue 'foreach,
117					Err(Error::Break) => return Ok(Value::None),
118					Err(err) => return Err(err),
119					_ => (),
120				};
121			}
122		}
123		// Ok all good
124		Ok(Value::None)
125	}
126}
127
128impl Display for ForeachStatement {
129	fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
130		write!(f, "FOR {} IN {} {}", self.param, self.range, self.block)
131	}
132}