1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
//! rspc-tauri: Tauri integration for [rspc](https://rspc.dev).
#![cfg_attr(docsrs2, feature(doc_cfg))]
#![doc(
    html_logo_url = "https://github.com/oscartbeaumont/rspc/raw/main/docs/public/logo.png",
    html_favicon_url = "https://github.com/oscartbeaumont/rspc/raw/main/docs/public/logo.png"
)]

use std::{borrow::Borrow, collections::HashMap, sync::Arc};

use tauri::{
    plugin::{Builder, TauriPlugin},
    AppHandle, Manager, Runtime,
};
use tokio::sync::{mpsc, Mutex};

use rspc::{
    internal::jsonrpc::{self, handle_json_rpc, Sender, SubscriptionMap},
    Router,
};

/// Builds the `"rspc"` Tauri plugin, which carries JSON-RPC traffic between
/// Tauri events and the given rspc [`Router`].
///
/// During plugin setup three pieces are wired together:
///
/// 1. A global listener on the `plugin:rspc:transport` event parses each event
///    payload as a [`jsonrpc::Request`] and pushes it onto an unbounded channel.
///    Events with no payload, or with a payload that fails to parse, are dropped
///    (and logged when the `tracing` feature is enabled).
/// 2. A dispatcher task drains that channel. For every request it calls `ctx_fn`
///    with a clone of the [`AppHandle`] to build a context and spawns a separate
///    task that runs [`handle_json_rpc`] for that request.
/// 3. An emitter task drains the response channel and emits every response to
///    all windows on the `plugin:rspc:transport:resp` event.
///
/// `router` is the router requests are executed against; `ctx_fn` is invoked
/// once per incoming request.
///
/// NOTE(review): tasks are started with `tokio::spawn`, so setup presumably has
/// to run inside a Tokio runtime context — confirm against how the host
/// application starts Tauri.
pub fn plugin<R: Runtime, TCtx, TMeta>(
    router: Arc<Router<TCtx, TMeta>>,
    ctx_fn: impl Fn(AppHandle<R>) -> TCtx + Send + Sync + 'static,
) -> TauriPlugin<R>
where
    TCtx: Send + 'static,
    TMeta: Send + Sync + 'static,
{
    Builder::new("rspc")
        .setup(|app_handle| {
            let (request_tx, mut request_rx) = mpsc::unbounded_channel::<jsonrpc::Request>();
            let (response_tx, mut response_rx) = mpsc::unbounded_channel::<jsonrpc::Response>();
            // TODO: Don't keep using a tokio mutex. We don't need to hold it over the await point.
            let subscriptions = Arc::new(Mutex::new(HashMap::new()));

            // Dispatcher: one spawned task per incoming request, so requests do
            // not wait on each other.
            let dispatch_handle = app_handle.clone();
            tokio::spawn(async move {
                while let Some(request) = request_rx.recv().await {
                    let ctx = ctx_fn(dispatch_handle.clone());
                    let router = router.clone();
                    let mut response_tx = response_tx.clone();
                    let subscriptions = subscriptions.clone();

                    tokio::spawn(async move {
                        let mut sender = Sender::ResponseChannel(&mut response_tx);
                        let mut subscription_map = SubscriptionMap::Mutex(subscriptions.borrow());

                        handle_json_rpc(ctx, request, &router, &mut sender, &mut subscription_map)
                            .await;
                    });
                }
            });

            // Emitter: forwards every response to the frontend. Emit failures
            // are only logged; the loop keeps running.
            let emit_handle = app_handle.clone();
            tokio::spawn(async move {
                while let Some(response) = response_rx.recv().await {
                    if let Err(err) = emit_handle.emit_all("plugin:rspc:transport:resp", response)
                    {
                        #[cfg(feature = "tracing")]
                        tracing::error!("failed to emit JSON-RPC response: {}", err);
                    }
                }
            });

            // Listener: turns raw Tauri events into parsed requests.
            app_handle.listen_global("plugin:rspc:transport", move |event| {
                let payload = match event.payload() {
                    Some(payload) => payload,
                    None => {
                        #[cfg(feature = "tracing")]
                        tracing::error!("Tauri event payload is empty");

                        return;
                    }
                };

                let request: jsonrpc::Request = match serde_json::from_str(payload) {
                    Ok(request) => request,
                    Err(err) => {
                        #[cfg(feature = "tracing")]
                        tracing::error!("failed to parse JSON-RPC request: {}", err);
                        return;
                    }
                };

                if let Err(err) = request_tx.send(request) {
                    #[cfg(feature = "tracing")]
                    tracing::error!("failed to send JSON-RPC request: {}", err);
                }
            });

            Ok(())
        })
        .build()
}