leptos_spin/
server_fn.rs

1use crate::request_parts::RequestParts;
2use crate::response_options::ResponseOptions;
3use crate::{
4    request::SpinRequest,
5    response::{SpinBody, SpinResponse},
6};
7use dashmap::DashMap;
8use futures::{SinkExt, StreamExt};
9use http::Method as HttpMethod;
10use leptos::server_fn::middleware::Service;
11/// Leptos Spin Integration for server functions
12use leptos::server_fn::{codec::Encoding, initialize_server_fn_map, ServerFn, ServerFnTraitObj};
13use leptos::{create_runtime, provide_context};
14use multimap::MultiMap;
15use once_cell::sync::Lazy;
16use spin_sdk::http::{Headers, IncomingRequest, OutgoingResponse, ResponseOutparam};
17use url::Url;
18
19#[allow(unused)] // used by server integrations
20type LazyServerFnMap<Req, Res> = Lazy<DashMap<&'static str, ServerFnTraitObj<Req, Res>>>;
21
22static REGISTERED_SERVER_FUNCTIONS: LazyServerFnMap<SpinRequest, SpinResponse> =
23    initialize_server_fn_map!(SpinRequest, SpinResponse);
24
25/// Explicitly register a server function. This is only necessary if you are
26/// running the server in a WASM environment (or a rare environment that the
27/// `inventory`) crate doesn't support. Spin is one of those environments
28pub fn register_explicit<T>()
29where
30    T: ServerFn<ServerRequest = SpinRequest, ServerResponse = SpinResponse> + 'static,
31{
32    REGISTERED_SERVER_FUNCTIONS.insert(
33        T::PATH,
34        ServerFnTraitObj::new(
35            T::PATH,
36            T::InputEncoding::METHOD,
37            |req| Box::pin(T::run_on_server(req)),
38            T::middlewares,
39        ),
40    );
41}
42
43/// The set of all registered server function paths.
44pub fn server_fn_paths() -> impl Iterator<Item = (&'static str, HttpMethod)> {
45    REGISTERED_SERVER_FUNCTIONS
46        .iter()
47        .map(|item| (item.path(), item.method()))
48}
49
50pub async fn handle_server_fns(req: IncomingRequest, resp_out: ResponseOutparam) {
51handle_server_fns_with_context(req, resp_out, ||{}).await;
52}
53pub async fn handle_server_fns_with_context(req: IncomingRequest, resp_out: ResponseOutparam, additional_context: impl Fn() + 'static + Clone + Send) {
54    let pq = req.path_with_query().unwrap_or_default();
55
56    let (spin_res, req_parts, res_options, runtime) = 
57        match crate::server_fn::get_server_fn_by_path(&pq) {
58            Some(lepfn) => {
59            let runtime = create_runtime();
60            let req_parts = RequestParts::new_from_req(&req);
61            provide_context(req_parts.clone());
62            let res_options = ResponseOptions::default_without_headers();
63            provide_context(res_options.clone());
64            additional_context();
65            let spin_req = SpinRequest::new_from_req(req);
66            (lepfn.clone().run(spin_req).await, req_parts, res_options, runtime)
67        },
68            None => panic!("Server FN path {} not found", &pq)
69        
70    };
71        // If the Accept header contains text/html, than this is a request from 
72        // a regular html form, so we should setup a redirect to either the referrer
73        // or the user specified location
74
75        let req_headers = Headers::from_list(req_parts.headers()).expect("Failed to construct Headers from Request Input for a Server Fn.");
76        let accepts_html = &req_headers.get(&"Accept".to_string())[0];
77        let accepts_html_bool = String::from_utf8_lossy(accepts_html).contains("text/html");
78
79        if accepts_html_bool {
80            
81            let referrer = &req_headers.get(&"Referer".to_string())[0];
82            let location = &req_headers.get(&"Location".to_string());
83            if location.is_empty(){
84                res_options.insert_header("location", referrer.to_owned());
85            }
86            // Set status and header for redirect
87            if !res_options.status_is_set(){
88                res_options.set_status(302);
89            }
90
91        } 
92
93    let headers = merge_headers(spin_res.0.headers, res_options.headers());
94    let status = res_options.status().unwrap_or(spin_res.0.status_code);
95    match spin_res.0.body {
96        SpinBody::Plain(r) => {
97            let og = OutgoingResponse::new(headers);
98            og.set_status_code(status).expect("Failed to set Status");
99            let mut ogbod = og.take_body();
100            resp_out.set(og);
101            ogbod.send(r).await.unwrap();
102        }
103        SpinBody::Streaming(mut s) => {
104            let og = OutgoingResponse::new(headers);
105            og.set_status_code(status).expect("Failed to set Status");
106            let mut res_body = og.take_body();
107            resp_out.set(og);
108
109            while let Some(Ok(chunk)) = s.next().await {
110                let _ = res_body.send(chunk.to_vec()).await;
111            }
112        }
113    }
114    runtime.dispose();
115}
116
117/// Returns the server function at the given path
118pub fn get_server_fn_by_path(path: &str) -> Option<ServerFnTraitObj<SpinRequest, SpinResponse>> {
119    // Sanitize Url to prevent query string or ids causing issues. To do that Url wants a full url,
120    // so we give it a fake one. We're only using the path anyway!
121    let full_url =format!("http://leptos.dev{}", path);
122    let Ok(url) = Url::parse(&full_url) else{
123        println!("Failed to parse: {full_url:?}");
124    return None;
125    };
126    REGISTERED_SERVER_FUNCTIONS.get_mut(url.path()).map(|f| f.clone())
127}
128
129/// Merge together two sets of headers, deleting any in the first set of Headers that have a key in
130/// the second set of headers.
131pub fn merge_headers(h1: Headers, h2: Headers) -> Headers {
132    //1. Get all keys in H1 and H2
133    let entries1 = h1.entries();
134    let entries2 = h2.entries();
135
136    let mut mmap1 = MultiMap::new();
137    entries1.iter().for_each(|(k, v)| {
138        mmap1.insert(k, v);
139    });
140    let mut mmap2 = MultiMap::new();
141    entries2.iter().for_each(|(k, v)| {
142        mmap2.insert(k, v);
143    });
144
145    //2. Delete any keys in H1 that are present in H2
146    mmap1.retain(|&k, &_v| mmap2.get(k).is_none());
147
148    //3. Iterate through H2, adding them to H1
149    mmap1.extend(mmap2);
150
151    //4. Profit
152    let mut merged_vec: Vec<(String, Vec<u8>)> = vec![];
153    mmap1.iter_all().for_each(|(k, v)| {
154        for v in v.iter() {
155            merged_vec.push((k.to_string(), v.to_vec()))
156        }
157    });
158    Headers::from_list(&merged_vec).expect("Failed to create headers from merged list")
159}