surrealdb_core/api/
body.rs

1#[cfg(not(target_family = "wasm"))]
2use std::fmt::Display;
3
4#[cfg(not(target_family = "wasm"))]
5use bytes::Bytes;
6#[cfg(not(target_family = "wasm"))]
7use futures::Stream;
8#[cfg(not(target_family = "wasm"))]
9use futures::StreamExt;
10use http::header::CONTENT_TYPE;
11
12use crate::err::Error;
13use crate::rpc::format::{cbor, json, msgpack, revision};
14use crate::sql::Bytesize;
15use crate::sql::Kind;
16use crate::sql::Value;
17
18use super::context::InvocationContext;
19use super::err::ApiError;
20use super::invocation::ApiInvocation;
21
22pub enum ApiBody {
23	#[cfg(not(target_family = "wasm"))]
24	Stream(Box<dyn Stream<Item = Result<Bytes, Box<dyn Display + Send + Sync>>> + Send + Unpin>),
25	Native(Value),
26}
27
28impl ApiBody {
29	#[cfg(not(target_family = "wasm"))]
30	pub fn from_stream<S, E>(stream: S) -> Self
31	where
32		S: Stream<Item = Result<Bytes, E>> + Unpin + Send + 'static,
33		E: Display + Send + Sync + 'static,
34	{
35		let mapped_stream =
36			stream.map(|result| result.map_err(|e| Box::new(e) as Box<dyn Display + Send + Sync>));
37		Self::Stream(Box::new(mapped_stream))
38	}
39
40	pub fn from_value(value: Value) -> Self {
41		Self::Native(value)
42	}
43
44	pub fn is_native(&self) -> bool {
45		matches!(self, Self::Native(_))
46	}
47
48	// The `max` variable is unused in WASM only
49	#[allow(unused_variables)]
50	pub async fn stream(self, max: Option<Bytesize>) -> Result<Vec<u8>, Error> {
51		match self {
52			#[cfg(not(target_family = "wasm"))]
53			Self::Stream(mut stream) => {
54				let max = max.unwrap_or(Bytesize::MAX);
55				let mut size: u64 = 0;
56				let mut bytes: Vec<u8> = Vec::new();
57
58				while let Some(chunk) = stream.next().await {
59					let chunk = chunk.map_err(|_| Error::ApiError(ApiError::InvalidRequestBody))?;
60					size += chunk.len() as u64;
61					if size > max.0 {
62						return Err(ApiError::RequestBodyTooLarge(max).into());
63					}
64
65					bytes.extend_from_slice(&chunk);
66				}
67
68				Ok(bytes)
69			}
70			_ => Err(Error::Unreachable(
71				"Encountered a native body whilst trying to stream one".into(),
72			)),
73		}
74	}
75
76	pub async fn process(
77		self,
78		ctx: &InvocationContext,
79		invocation: &ApiInvocation,
80	) -> Result<Value, Error> {
81		#[allow(irrefutable_let_patterns)] // For WASM this is the only pattern
82		if let ApiBody::Native(value) = self {
83			let max = ctx.request_body_max.to_owned().unwrap_or(Bytesize::MAX);
84			let size = std::mem::size_of_val(&value);
85
86			if size > max.0 as usize {
87				return Err(ApiError::RequestBodyTooLarge(max).into());
88			}
89
90			if ctx.request_body_raw {
91				value.coerce_to(&Kind::Bytes)
92			} else {
93				Ok(value)
94			}
95		} else {
96			let bytes = self.stream(ctx.request_body_max.to_owned()).await?;
97
98			if ctx.request_body_raw {
99				Ok(Value::Bytes(crate::sql::Bytes(bytes)))
100			} else {
101				let content_type =
102					invocation.headers.get(CONTENT_TYPE).and_then(|v| v.to_str().ok());
103
104				let parsed = match content_type {
105					Some("application/json") => json::parse_value(&bytes),
106					Some("application/cbor") => cbor::parse_value(bytes),
107					Some("application/pack") => msgpack::parse_value(bytes),
108					Some("application/surrealdb") => revision::parse_value(bytes),
109					_ => return Ok(Value::Bytes(crate::sql::Bytes(bytes))),
110				};
111
112				parsed.map_err(|_| Error::ApiError(ApiError::BodyDecodeFailure))
113			}
114		}
115	}
116}