rama_proxy/proxydb/
update.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
use super::ProxyDB;
use arc_swap::ArcSwap;
use rama_core::error::{BoxError, OpaqueError};
use std::{fmt, ops::Deref, sync::Arc};

/// Create a new [`ProxyDB`] updater which allows you to have a (typically in-memory) [`ProxyDB`]
/// which you can update live.
///
/// This construct returns a pair of:
///
/// - [`LiveUpdateProxyDB`]: to be used as the [`ProxyDB`] instead of the inner `T`, dubbed the "reader";
/// - [`LiveUpdateProxyDBSetter`]: to be used as the _only_ way to set the inner `T` as many times as you wish, dubbed the "writer".
///
/// Note that the inner `T` is not yet created when this construct returns this pair.
/// Until you have called [`LiveUpdateProxyDBSetter::set`] with the inner `T` [`ProxyDB`],
/// any [`ProxyDB`] trait method call to [`LiveUpdateProxyDB`] will fail.
///
/// It is therefore recommended that you immediately set the inner `T` [`ProxyDB`] upon
/// receiving the reader/writer pair, prior to starting to actually use the [`ProxyDB`]
/// in your rama service stack.
///
/// The goal of this updater is to be fast for reading (getting proxies),
/// and slow for the infrequent updates (setting the proxy db). As such it is recommended
/// not to update the [`ProxyDB`] too frequently. An example use case for this updater
/// could be to update your in-memory proxy database every 15 minutes, by populating it from
/// a shared external database (e.g. MySQL). Failures to create a new `T` [`ProxyDB`] should be
/// handled by the writer, and can be as simple as just logging them and moving on without an update.
pub fn proxy_db_updater<T>() -> (LiveUpdateProxyDB<T>, LiveUpdateProxyDBSetter<T>)
where
    T: ProxyDB<Error: Into<BoxError>>,
{
    // Reader and writer share one swap cell. It starts out empty (`None`),
    // which is why reader calls fail until the writer stores a first `T`.
    let data = Arc::new(ArcSwap::from_pointee(None));
    let reader = LiveUpdateProxyDB(Arc::clone(&data));
    let writer = LiveUpdateProxyDBSetter(data);
    (reader, writer)
}

/// The "reader" half created by [`proxy_db_updater`].
///
/// It acts as a [`ProxyDB`] by forwarding each query to whichever `T` [`ProxyDB`]
/// was most recently stored through the _only_ linked writer, a
/// [`LiveUpdateProxyDBSetter`]. Queries fail while no `T` has been stored yet.
///
/// See [`proxy_db_updater`] for more details.
pub struct LiveUpdateProxyDB<T>(Arc<ArcSwap<Option<T>>>);

impl<T> Clone for LiveUpdateProxyDB<T> {
    /// Clones only the shared handle: every clone reads from the same swap cell
    /// and therefore observes the same updates.
    fn clone(&self) -> Self {
        let shared = Arc::clone(&self.0);
        Self(shared)
    }
}

impl<T: fmt::Debug> fmt::Debug for LiveUpdateProxyDB<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut tuple = f.debug_tuple("LiveUpdateProxyDB");
        tuple.field(&self.0);
        tuple.finish()
    }
}

/// Forwards every [`ProxyDB`] query to the currently stored `T`, converting its
/// error into a [`BoxError`]. While no `T` has been stored yet, every call fails
/// with an [`OpaqueError`].
impl<T> ProxyDB for LiveUpdateProxyDB<T>
where
    T: ProxyDB<Error: Into<BoxError>>,
{
    type Error = BoxError;

    async fn get_proxy_if(
        &self,
        ctx: rama_net::transport::TransportContext,
        filter: super::ProxyFilter,
        predicate: impl super::ProxyQueryPredicate,
    ) -> Result<super::Proxy, Self::Error> {
        // Take an owned snapshot (`load_full`) instead of a `Guard` (`load`).
        // The value is held across the inner `.await`, which may take arbitrarily
        // long, and arc-swap guards are meant for short-lived local borrows only.
        // A concurrent `set` does not affect this in-flight call.
        let snapshot = self.0.load_full();
        let Some(db) = snapshot.deref() else {
            return Err(OpaqueError::from_display(
                "live proxy db: proxy db is None: get_proxy_if unable to proceed",
            )
            .into());
        };
        db.get_proxy_if(ctx, filter, predicate)
            .await
            .map_err(Into::into)
    }

    async fn get_proxy(
        &self,
        ctx: rama_net::transport::TransportContext,
        filter: super::ProxyFilter,
    ) -> Result<super::Proxy, Self::Error> {
        // Owned snapshot for the same reason as in `get_proxy_if`:
        // it stays alive across the `.await` below.
        let snapshot = self.0.load_full();
        let Some(db) = snapshot.deref() else {
            return Err(OpaqueError::from_display(
                "live proxy db: proxy db is None: get_proxy unable to proceed",
            )
            .into());
        };
        db.get_proxy(ctx, filter).await.map_err(Into::into)
    }
}

/// The "writer" half created by [`proxy_db_updater`].
///
/// It replaces the `T` [`ProxyDB`] served by every linked [`LiveUpdateProxyDB`].
/// There is only one such writer for all readers that share the same
/// internal data `T`.
///
/// See [`proxy_db_updater`] for more details.
pub struct LiveUpdateProxyDBSetter<T>(Arc<ArcSwap<Option<T>>>);

impl<T> LiveUpdateProxyDBSetter<T> {
    /// Store `db` as the `T` [`ProxyDB`] used by all future [`ProxyDB`]
    /// calls made through the linked [`LiveUpdateProxyDB`] instances.
    pub fn set(&self, db: T) {
        let next = Arc::new(Some(db));
        self.0.store(next);
    }
}

impl<T: fmt::Debug> fmt::Debug for LiveUpdateProxyDBSetter<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut tuple = f.debug_tuple("LiveUpdateProxyDBSetter");
        tuple.field(&self.0);
        tuple.finish()
    }
}

#[cfg(test)]
mod tests {
    use crate::{Proxy, ProxyFilter};
    use rama_net::{
        asn::Asn,
        transport::{TransportContext, TransportProtocol},
    };
    use rama_utils::str::NonEmptyString;

    use super::*;

    /// Transport context shared by every lookup in this module; only the
    /// transport protocol varies between calls.
    fn ctx_for(protocol: TransportProtocol) -> TransportContext {
        TransportContext {
            protocol,
            app_protocol: None,
            http_version: None,
            authority: "proxy.example.com:1080".parse().unwrap(),
        }
    }

    /// Fixture proxy with id `"id"`, flagged `tcp: true` and `udp: false`.
    fn tcp_only_proxy() -> Proxy {
        Proxy {
            id: NonEmptyString::from_static("id"),
            address: "authority".parse().unwrap(),
            tcp: true,
            udp: false,
            http: false,
            https: true,
            socks5: false,
            socks5h: false,
            datacenter: true,
            residential: false,
            mobile: true,
            pool_id: Some("pool_id".into()),
            continent: Some("continent".into()),
            country: Some("country".into()),
            state: Some("state".into()),
            city: Some("city".into()),
            carrier: Some("carrier".into()),
            asn: Some(Asn::from_static(1)),
        }
    }

    #[tokio::test]
    async fn test_empty_live_update_db() {
        let (reader, _) = proxy_db_updater::<Proxy>();
        let outcome = reader
            .get_proxy(ctx_for(TransportProtocol::Tcp), ProxyFilter::default())
            .await;
        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn test_live_update_db_updated() {
        let (reader, writer) = proxy_db_updater::<Proxy>();

        // Nothing stored yet: the lookup must fail.
        let before_set = reader
            .get_proxy(ctx_for(TransportProtocol::Tcp), ProxyFilter::default())
            .await;
        assert!(before_set.is_err());

        writer.set(tcp_only_proxy());

        // A TCP lookup now resolves to the stored proxy.
        let first_tcp = reader
            .get_proxy(ctx_for(TransportProtocol::Tcp), ProxyFilter::default())
            .await
            .unwrap();
        assert_eq!("id", first_tcp.id);

        // The fixture is flagged `udp: false`; this lookup is asserted to fail.
        let udp_attempt = reader
            .get_proxy(ctx_for(TransportProtocol::Udp), ProxyFilter::default())
            .await;
        assert!(udp_attempt.is_err());

        // The failed UDP lookup must not disturb a later TCP lookup.
        let second_tcp = reader
            .get_proxy(ctx_for(TransportProtocol::Tcp), ProxyFilter::default())
            .await
            .unwrap();
        assert_eq!("id", second_tcp.id);
    }
}