mpd_utils/
persistent_client.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
use crate::socket::try_get_connection;
use mpd_client::client::{CommandError, ConnectionEvent};
use mpd_client::commands::Command;
use mpd_client::responses::{SongInQueue, Status};
use mpd_client::{commands, Client};
use std::future::Future;
use std::sync::{Arc, RwLock};
use std::time::Duration;
use tokio::spawn;
use tokio::sync::broadcast;
use tokio::sync::broadcast::error::RecvError;
use tokio::time::sleep;
use tracing::{debug, error, info};

#[derive(Debug, Clone)]
enum State {
    Disconnected,
    Connected(Arc<Client>),
}

type Channel<T> = (broadcast::Sender<T>, broadcast::Receiver<T>);

/// MPD client which automatically attempts to reconnect
/// if the connection cannot be established or is lost.
///
/// Commands sent to a disconnected client are queued.
#[derive(Debug)]
pub struct PersistentClient {
    host: String,
    retry_interval: Duration,
    state: Arc<RwLock<State>>,
    channel: Channel<Arc<ConnectionEvent>>,
    connection_channel: Channel<Arc<Client>>,
}

impl PersistentClient {
    pub fn new(host: String, retry_interval: Duration) -> Self {
        let channel = broadcast::channel(1024);
        let connection_channel = broadcast::channel(8);

        Self {
            host,
            retry_interval,
            state: Arc::new(RwLock::new(State::Disconnected)),
            channel,
            connection_channel,
        }
    }

    /// Attempts to connect to the MPD host
    /// and begins listening to server events.
    pub fn init(&self) {
        let host = self.host.clone();
        let retry_interval = self.retry_interval;
        let state = self.state.clone();
        let tx = self.channel.0.clone();
        let conn_tx = self.connection_channel.0.clone();

        spawn(async move {
            loop {
                debug!("Attempting to connect to {host}");
                let connection = try_get_connection(&host).await;

                match connection {
                    Ok(connection) => {
                        info!("Connected to '{host}'");

                        let client = Arc::new(connection.0);

                        {
                            *state.write().expect("Failed to get lock on state") =
                                State::Connected(client.clone());
                            conn_tx.send(client).expect("Failed to send event");
                        }

                        let mut events = connection.1;

                        while let Some(event) = events.next().await {
                            if let ConnectionEvent::ConnectionClosed(err) = event {
                                error!("Lost connection to '{host}': {err:?}");
                                *state.write().expect("Failed to get lock on state") =
                                    State::Disconnected;

                                break;
                            }

                            debug!("Sending event: {event:?}");

                            // Wrap in `Arc` because `ConnectionEvent` isn't `Clone`.
                            tx.send(Arc::new(event)).expect("Failed to send event");
                        }
                    }
                    Err(err) => {
                        error!("Failed to connect to '{host}': {err:?}");
                        *state.write().expect("Failed to get lock on state") = State::Disconnected;
                    }
                }

                sleep(retry_interval).await;
            }
        });
    }

    /// Gets the client host address or path
    pub fn host(&self) -> &str {
        &self.host
    }

    /// Gets whether there is a valid connection to the server
    pub fn is_connected(&self) -> bool {
        matches!(
            *self.state.read().expect("Failed to get lock on state"),
            State::Connected(_)
        )
    }

    /// Waits for a valid connection to the server to be established.
    /// If already connected, resolves immediately.
    pub async fn wait_for_client(&self) -> Arc<Client> {
        {
            let state = self.state.read().expect("Failed to get lock on state");

            if let State::Connected(client) = &*state {
                return client.clone();
            }
        }

        let mut rx = self.connection_channel.0.subscribe();
        rx.recv().await.unwrap()
    }

    /// Runs the provided callback as soon as the connected client is available.
    pub async fn with_client<F, Fut, T>(&self, f: F) -> T
    where
        F: FnOnce(Arc<Client>) -> Fut,
        Fut: Future<Output = T>,
    {
        let client = self.wait_for_client().await;
        f(client).await
    }

    /// Receives an event from the MPD server.
    pub async fn recv(&mut self) -> Result<Arc<ConnectionEvent>, RecvError> {
        let rx = &mut self.channel.1;
        rx.recv().await
    }

    /// Creates a new receiver to be able to receive events
    /// outside the context of `&self`.
    ///
    /// When you have access to the client instance, prefer` recv()` instead.
    pub fn subscribe(&self) -> broadcast::Receiver<Arc<ConnectionEvent>> {
        self.channel.0.subscribe()
    }

    /// Runs the provided command on the MPD server.
    ///
    /// Waits for a valid connection and response before the future is completed.
    pub async fn command<C: Command>(&self, cmd: C) -> Result<C::Response, CommandError> {
        self.with_client(|client| async move { client.command(cmd).await })
            .await
    }

    /// Runs the `status` command on the MPD server.
    ///
    /// Waits for a valid connection and response before the future is completed.
    pub async fn status(&self) -> Result<Status, CommandError> {
        self.command(commands::Status).await
    }

    /// Runs the `currentsong` command on the MPD server.
    ///
    /// Waits for a valid connection and response before the future is completed.
    pub async fn current_song(&self) -> Result<Option<SongInQueue>, CommandError> {
        self.command(commands::CurrentSong).await
    }
}

/// Creates a new client on the default localhost TCP address
/// with a connection retry of 5 seconds.
impl Default for PersistentClient {
    fn default() -> Self {
        PersistentClient::new("localhost:6600".to_string(), Duration::from_secs(5))
    }
}

#[cfg(test)]
mod tests {
    use crate::*;
    use mpd_client::commands;

    #[tokio::test]
    async fn test() {
        let client = PersistentClient::default();
        client.init();

        let status = client
            .with_client(|client| async move { client.command(commands::Status).await })
            .await
            .unwrap();

        println!("{:?}", status);
    }
}