surrealdb_core/api/
body.rs1#[cfg(not(target_family = "wasm"))]
2use std::fmt::Display;
3
4#[cfg(not(target_family = "wasm"))]
5use bytes::Bytes;
6#[cfg(not(target_family = "wasm"))]
7use futures::Stream;
8#[cfg(not(target_family = "wasm"))]
9use futures::StreamExt;
10use http::header::CONTENT_TYPE;
11
12use crate::err::Error;
13use crate::rpc::format::{cbor, json, msgpack, revision};
14use crate::sql::Bytesize;
15use crate::sql::Kind;
16use crate::sql::Value;
17
18use super::context::InvocationContext;
19use super::err::ApiError;
20use super::invocation::ApiInvocation;
21
22pub enum ApiBody {
23 #[cfg(not(target_family = "wasm"))]
24 Stream(Box<dyn Stream<Item = Result<Bytes, Box<dyn Display + Send + Sync>>> + Send + Unpin>),
25 Native(Value),
26}
27
28impl ApiBody {
29 #[cfg(not(target_family = "wasm"))]
30 pub fn from_stream<S, E>(stream: S) -> Self
31 where
32 S: Stream<Item = Result<Bytes, E>> + Unpin + Send + 'static,
33 E: Display + Send + Sync + 'static,
34 {
35 let mapped_stream =
36 stream.map(|result| result.map_err(|e| Box::new(e) as Box<dyn Display + Send + Sync>));
37 Self::Stream(Box::new(mapped_stream))
38 }
39
40 pub fn from_value(value: Value) -> Self {
41 Self::Native(value)
42 }
43
44 pub fn is_native(&self) -> bool {
45 matches!(self, Self::Native(_))
46 }
47
48 #[allow(unused_variables)]
50 pub async fn stream(self, max: Option<Bytesize>) -> Result<Vec<u8>, Error> {
51 match self {
52 #[cfg(not(target_family = "wasm"))]
53 Self::Stream(mut stream) => {
54 let max = max.unwrap_or(Bytesize::MAX);
55 let mut size: u64 = 0;
56 let mut bytes: Vec<u8> = Vec::new();
57
58 while let Some(chunk) = stream.next().await {
59 let chunk = chunk.map_err(|_| Error::ApiError(ApiError::InvalidRequestBody))?;
60 size += chunk.len() as u64;
61 if size > max.0 {
62 return Err(ApiError::RequestBodyTooLarge(max).into());
63 }
64
65 bytes.extend_from_slice(&chunk);
66 }
67
68 Ok(bytes)
69 }
70 _ => Err(Error::Unreachable(
71 "Encountered a native body whilst trying to stream one".into(),
72 )),
73 }
74 }
75
76 pub async fn process(
77 self,
78 ctx: &InvocationContext,
79 invocation: &ApiInvocation,
80 ) -> Result<Value, Error> {
81 #[allow(irrefutable_let_patterns)] if let ApiBody::Native(value) = self {
83 let max = ctx.request_body_max.to_owned().unwrap_or(Bytesize::MAX);
84 let size = std::mem::size_of_val(&value);
85
86 if size > max.0 as usize {
87 return Err(ApiError::RequestBodyTooLarge(max).into());
88 }
89
90 if ctx.request_body_raw {
91 value.coerce_to(&Kind::Bytes)
92 } else {
93 Ok(value)
94 }
95 } else {
96 let bytes = self.stream(ctx.request_body_max.to_owned()).await?;
97
98 if ctx.request_body_raw {
99 Ok(Value::Bytes(crate::sql::Bytes(bytes)))
100 } else {
101 let content_type =
102 invocation.headers.get(CONTENT_TYPE).and_then(|v| v.to_str().ok());
103
104 let parsed = match content_type {
105 Some("application/json") => json::parse_value(&bytes),
106 Some("application/cbor") => cbor::parse_value(bytes),
107 Some("application/pack") => msgpack::parse_value(bytes),
108 Some("application/surrealdb") => revision::parse_value(bytes),
109 _ => return Ok(Value::Bytes(crate::sql::Bytes(bytes))),
110 };
111
112 parsed.map_err(|_| Error::ApiError(ApiError::BodyDecodeFailure))
113 }
114 }
115 }
116}