async_nats/service/
endpoint.rs

1// Copyright 2020-2023 The NATS Authors
2// Licensed under the Apache License, Version 2.0 (the "License");
3// you may not use this file except in compliance with the License.
4// You may obtain a copy of the License at
5//
6// http://www.apache.org/licenses/LICENSE-2.0
7//
8// Unless required by applicable law or agreed to in writing, software
9// distributed under the License is distributed on an "AS IS" BASIS,
10// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11// See the License for the specific language governing permissions and
12// limitations under the License.
13
14use std::{
15    collections::HashMap,
16    sync::{Arc, Mutex},
17    task::Poll,
18    time::Instant,
19};
20
21use futures::{Stream, StreamExt};
22use serde::{Deserialize, Deserializer, Serialize};
23use tracing::{debug, trace};
24
25use crate::{Client, Subscriber};
26
27use super::{error, Endpoints, Request, ShutdownReceiverFuture};
28
/// A single service endpoint, consumed as a [Stream] of [Request]s.
///
/// Requests arrive on the `requests` subscription. If a shutdown receiver is
/// present, receiving the stop broadcast makes the stream unsubscribe from that
/// subscription (see the [Stream] implementation).
pub struct Endpoint {
    /// Subscription on which requests for this endpoint arrive.
    pub(crate) requests: Subscriber,
    /// Shared stats handle, cloned into every [Request] yielded by the stream.
    pub(crate) stats: Arc<Mutex<Endpoints>>,
    /// Client handle, cloned into every [Request] yielded by the stream.
    pub(crate) client: Client,
    /// Endpoint identifier, cloned into every [Request].
    // NOTE(review): presumably the key of this endpoint in the shared stats map — confirm.
    pub(crate) endpoint: String,
    /// Shutdown broadcast receiver; taken on the first poll and turned into
    /// `shutdown_future`.
    pub(crate) shutdown: Option<tokio::sync::broadcast::Receiver<()>>,
    /// Pending wait on the shutdown broadcast. Initialized from `shutdown` on the
    /// first poll and reset to `None` once the broadcast has been received.
    pub(crate) shutdown_future: Option<ShutdownReceiverFuture>,
}
37
38impl Stream for Endpoint {
39    type Item = Request;
40
41    fn poll_next(
42        mut self: std::pin::Pin<&mut Self>,
43        cx: &mut std::task::Context<'_>,
44    ) -> std::task::Poll<Option<Self::Item>> {
45        trace!("polling for next request");
46        if let Some(mut receiver) = self.shutdown.take() {
47            // Need to initialize `shutdown_future` on first poll
48            self.shutdown_future = Some(Box::pin(async move { receiver.recv().await }));
49        }
50
51        if let Some(shutdown) = self.shutdown_future.as_mut() {
52            match shutdown.as_mut().poll(cx) {
53                Poll::Ready(_result) => {
54                    debug!("got stop broadcast");
55                    self.requests
56                        .sender
57                        .try_send(crate::Command::Unsubscribe {
58                            sid: self.requests.sid,
59                            max: None,
60                        })
61                        .ok();
62
63                    // Clear future, can't be resumed after completion
64                    self.shutdown_future = None;
65                }
66                Poll::Pending => {
67                    trace!("stop broadcast still pending");
68                }
69            }
70        }
71
72        trace!("checking for new messages");
73        match self.requests.poll_next_unpin(cx) {
74            Poll::Ready(message) => {
75                debug!("got next message");
76                match message {
77                    Some(message) => Poll::Ready(Some(Request {
78                        issued: Instant::now(),
79                        stats: self.stats.clone(),
80                        client: self.client.clone(),
81                        message,
82                        endpoint: self.endpoint.clone(),
83                    })),
84                    None => Poll::Ready(None),
85                }
86            }
87
88            Poll::Pending => {
89                trace!("still pending for messages");
90                Poll::Pending
91            }
92        }
93    }
94
95    fn size_hint(&self) -> (usize, Option<usize>) {
96        (0, None)
97    }
98}
99
100impl Endpoint {
101    /// Stops the [Endpoint] and unsubscribes from the subject.
102    pub async fn stop(&mut self) -> Result<(), std::io::Error> {
103        self.requests
104            .unsubscribe()
105            .await
106            .map_err(|_| std::io::Error::new(std::io::ErrorKind::Other, "failed to unsubscribe"))
107    }
108}
109
/// Internal stats record for a single endpoint.
///
/// Holds every field of [Stats] plus the response `kind` (serialized as `type`)
/// and the endpoint `metadata`; converting into [Stats] drops those two fields.
/// Unlike [Stats], `last_error` here uses `error::Error`'s own serde
/// representation rather than the `"code:status"` string form.
///
/// NOTE(review): an earlier version of this doc stated that there is currently
/// only one business endpoint and all others are internal — confirm this still holds.
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
pub(crate) struct Inner {
    /// Response type, serialized as `type`.
    #[serde(rename = "type")]
    pub(crate) kind: String,
    /// Endpoint name.
    pub(crate) name: String,
    /// The subject on which the endpoint is registered.
    pub(crate) subject: String,
    /// Endpoint-specific metadata (required when deserializing; no default is applied).
    pub(crate) metadata: HashMap<String, String>,
    /// Number of requests handled, serialized as `num_requests`.
    #[serde(rename = "num_requests")]
    pub(crate) requests: usize,
    /// Number of errors that occurred, serialized as `num_errors`.
    #[serde(rename = "num_errors")]
    pub(crate) errors: usize,
    /// Total processing time for all requests; zero when absent on deserialize.
    #[serde(default, with = "serde_nanos")]
    pub(crate) processing_time: std::time::Duration,
    /// Average processing time per request; zero when absent on deserialize.
    #[serde(default, with = "serde_nanos")]
    pub(crate) average_processing_time: std::time::Duration,
    /// Last error that occurred, if any.
    pub(crate) last_error: Option<error::Error>,
    /// Custom data added by [crate::service::Config::stats_handler].
    pub(crate) data: Option<serde_json::Value>,
    /// Queue group to which this endpoint is assigned.
    pub(crate) queue_group: String,
}
142
143impl From<Inner> for Stats {
144    fn from(inner: Inner) -> Self {
145        Stats {
146            name: inner.name,
147            subject: inner.subject,
148            requests: inner.requests,
149            errors: inner.errors,
150            processing_time: inner.processing_time,
151            average_processing_time: inner.average_processing_time,
152            last_error: inner.last_error,
153            data: inner.data,
154            queue_group: inner.queue_group,
155        }
156    }
157}
158
/// Stats of a single endpoint.
///
/// Produced from the internal per-endpoint record via `From`. When serialized:
/// - `requests` and `errors` are named `num_requests` and `num_errors`.
/// - The processing times go through `serde_nanos` and default to zero when absent.
/// - `last_error` is a `"code:status"` string, empty when there is no error.
/// - `data` is omitted when it is `None`.
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
pub struct Stats {
    /// Endpoint name.
    pub name: String,
    /// The subject on which the endpoint is registered.
    pub subject: String,
    /// Number of requests handled, serialized as `num_requests`.
    #[serde(rename = "num_requests")]
    pub requests: usize,
    /// Number of errors that occurred, serialized as `num_errors`.
    #[serde(rename = "num_errors")]
    pub errors: usize,
    /// Total processing time for all requests; zero when absent on deserialize.
    #[serde(default, with = "serde_nanos")]
    pub processing_time: std::time::Duration,
    /// Average processing time per request; zero when absent on deserialize.
    #[serde(default, with = "serde_nanos")]
    pub average_processing_time: std::time::Duration,
    /// Last error that occurred, encoded as a `"code:status"` string (`""` for none).
    #[serde(with = "serde_error_string")]
    pub last_error: Option<error::Error>,
    /// Custom data added by [crate::service::Config::stats_handler]; omitted from
    /// the serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub data: Option<serde_json::Value>,
    /// Queue group to which this endpoint is assigned.
    pub queue_group: String,
}
186
187mod serde_error_string {
188    use serde::{Deserialize, Deserializer, Serializer};
189
190    use super::error;
191
192    pub(crate) fn serialize<S>(
193        error: &Option<error::Error>,
194        serializer: S,
195    ) -> Result<S::Ok, S::Error>
196    where
197        S: Serializer,
198    {
199        match error {
200            Some(error) => serializer.serialize_str(&format!("{}:{}", error.code, error.status)),
201            None => serializer.serialize_str(""),
202        }
203    }
204
205    pub(crate) fn deserialize<'de, D>(deserializer: D) -> Result<Option<error::Error>, D::Error>
206    where
207        D: Deserializer<'de>,
208    {
209        let string = String::deserialize(deserializer)?;
210        if string.is_empty() {
211            Ok(None)
212        } else if let Some((code, status)) = &string.split_once(':') {
213            let err_code: usize = code.parse().unwrap_or(0);
214            let status = if err_code == 0 {
215                string.clone()
216            } else {
217                status.to_string()
218            };
219            Ok(Some(error::Error {
220                code: err_code,
221                status,
222            }))
223        } else {
224            Ok(Some(error::Error {
225                code: 0,
226                status: string,
227            }))
228        }
229    }
230}
231
/// Descriptive information about a single endpoint.
#[derive(Serialize, Deserialize, Debug, Clone, Default, PartialEq, Eq)]
pub struct Info {
    /// Name of the endpoint.
    pub name: String,
    /// Endpoint subject.
    pub subject: String,
    /// Queue group to which this endpoint is assigned.
    pub queue_group: String,
    /// Endpoint-specific metadata.
    ///
    /// Both a missing field and an explicit `null` deserialize to an empty map.
    #[serde(default, deserialize_with = "null_meta_as_default")]
    pub metadata: HashMap<String, String>,
}
244
245pub(crate) fn null_meta_as_default<'de, D>(
246    deserializer: D,
247) -> Result<HashMap<String, String>, D::Error>
248where
249    D: Deserializer<'de>,
250{
251    let metadata: Option<HashMap<String, String>> = Option::deserialize(deserializer)?;
252    Ok(metadata.unwrap_or_default())
253}
254
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn error_serde() {
        // Wrapper so the `serde_error_string` helpers can be driven through serde_json.
        #[derive(Serialize, Deserialize, PartialEq, Eq, Debug)]
        struct WithOptionalError {
            #[serde(with = "serde_error_string")]
            error: Option<error::Error>,
        }

        // Shorthand for an expected error value.
        let err = |code: usize, status: &str| error::Error {
            code,
            status: status.to_string(),
        };
        // Shorthand for decoding a JSON document into the wrapper.
        let parse = |json: &str| -> WithOptionalError { serde_json::from_str(json).unwrap() };

        // A present error round-trips through the "code:status" string form.
        let with_error = WithOptionalError {
            error: Some(err(500, "error")),
        };
        let json = serde_json::to_string(&with_error).unwrap();
        assert_eq!(json, r#"{"error":"500:error"}"#);
        assert_eq!(parse(&json), with_error);

        // An absent error is written as an empty string and read back as `None`.
        let without_error = WithOptionalError { error: None };
        let json = serde_json::to_string(&without_error).unwrap();
        assert_eq!(json, r#"{"error":""}"#);
        assert_eq!(parse(&json), without_error);

        // A string with no code becomes code 0 with the whole string as status.
        assert_eq!(
            parse(r#"{"error":"error"}"#),
            WithOptionalError {
                error: Some(err(0, "error")),
            }
        );

        // A non-numeric code is not split off: code 0, whole string as status.
        assert_eq!(
            parse(r#"{"error":"invalid:error"}"#),
            WithOptionalError {
                error: Some(err(0, "invalid:error")),
            }
        );
    }
}