use snafu::{ResultExt, Snafu};
use std::ops::{Deref, DerefMut};
use std::sync::atomic::{AtomicU32, Ordering};
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::Mutex;
use tokio::time::interval;
use tokio_executor_trait::Tokio as TokioExecutor;
use tokio_reactor_trait::Tokio as TokioReactor;
/// Errors returned by [`RabbitMqPool`] operations.
#[derive(Debug, Snafu)]
pub enum Error {
    /// Opening a new broker connection (`lapin::Connection::connect`) failed.
    #[snafu(display("Failed to connect to RabbitMQ: {source}"))]
    Connection { source: lapin::Error },
    /// Opening a channel on a pooled connection failed.
    #[snafu(display("Failed to create channel: {source}"))]
    Channel { source: lapin::Error },
    /// Closing a pooled connection during [`RabbitMqPool::close`] failed.
    #[snafu(display("Failed to close connection: {source}"))]
    Close { source: lapin::Error },
}
/// A channel handed out by [`RabbitMqPool::create_channel`].
///
/// Dereferences to the wrapped [`lapin::Channel`]. Dropping it decrements the
/// live-channel counter of the pooled connection it was opened on; the pool reads
/// that counter to pick a connection for new channels and to find idle connections.
pub struct RabbitMqChannel {
    channel: lapin::Channel,
    /// Live-channel counter shared with the owning pool entry (`ConnectionEntry::channels`).
    ref_counter: Arc<AtomicU32>,
}
impl Drop for RabbitMqChannel {
fn drop(&mut self) {
self.ref_counter.fetch_sub(1, Ordering::Relaxed);
}
}
impl Deref for RabbitMqChannel {
    type Target = lapin::Channel;

    /// Exposes the wrapped lapin channel so its methods can be called directly.
    fn deref(&self) -> &Self::Target {
        let Self { channel, .. } = self;
        channel
    }
}
impl DerefMut for RabbitMqChannel {
    /// Mutable access to the wrapped lapin channel.
    fn deref_mut(&mut self) -> &mut Self::Target {
        let Self { channel, .. } = self;
        channel
    }
}
/// A lazily populated pool of RabbitMQ connections that hands out channels.
///
/// Connections are opened on demand by `create_channel` and shared by several
/// channels, up to `max_channels_per_connection` live channels each. A background
/// task spawned by `from_config` periodically removes connections from the pool.
pub struct RabbitMqPool {
    /// Broker URL passed to `lapin::Connection::connect`.
    url: String,
    /// Read by the background reaper when deciding whether connections may be removed.
    min_connections: u32,
    /// A connection with this many live channels is not used for new channels.
    max_channels_per_connection: u32,
    /// Pooled connections. This is a `tokio` mutex because `create_channel` and
    /// `close` hold the guard across `.await` points.
    connections: Mutex<Vec<ConnectionEntry>>,
}
/// One pooled connection together with the number of live channels opened on it.
struct ConnectionEntry {
    connection: lapin::Connection,
    /// Shared with every `RabbitMqChannel` opened on `connection`: incremented in
    /// `create_channel`, decremented when such a channel is dropped.
    channels: Arc<AtomicU32>,
}
impl RabbitMqPool {
pub fn from_config(
url: &str,
min_connections: u32,
max_channels_per_connection: u32,
) -> Arc<Self> {
let this = Arc::new(Self {
url: url.into(),
min_connections,
max_channels_per_connection,
connections: Mutex::new(vec![]),
});
tokio::spawn(reap_unused_connections(this.clone()));
this
}
pub async fn make_connection(&self) -> Result<lapin::Connection, Error> {
let connection = lapin::Connection::connect(
&self.url,
lapin::ConnectionProperties::default()
.with_executor(TokioExecutor::current())
.with_reactor(TokioReactor),
)
.await
.context(ConnectionSnafu)?;
Ok(connection)
}
pub async fn create_channel(&self) -> Result<RabbitMqChannel, Error> {
let mut connections = self.connections.lock().await;
let entry = connections.iter().find(|entry| {
entry.channels.load(Ordering::Relaxed) < self.max_channels_per_connection
&& entry.connection.status().connected()
});
let entry = if let Some(entry) = entry {
entry
} else {
let connection = self.make_connection().await?;
let channels = Arc::new(AtomicU32::new(0));
connections.push(ConnectionEntry {
connection,
channels,
});
connections.last().expect("Item was just pushed.")
};
let channel = entry
.connection
.create_channel()
.await
.context(ChannelSnafu)?;
entry.channels.fetch_add(1, Ordering::Relaxed);
Ok(RabbitMqChannel {
channel,
ref_counter: entry.channels.clone(),
})
}
pub async fn close(&self, reply_code: u16, reply_message: &str) -> Result<(), Error> {
let mut connections = self.connections.lock().await;
for entry in connections.drain(..) {
entry
.connection
.close(reply_code, reply_message)
.await
.context(CloseSnafu)?;
}
Ok(())
}
}
/// Background task spawned by [`RabbitMqPool::from_config`]. Every 10 seconds it
/// prunes the pool.
///
/// On each tick:
/// * connections that are no longer connected are removed, because
///   `create_channel` never hands them out again;
/// * connected connections with no live channels are removed and closed, but
///   only while the pool holds more than `min_connections` connections. Idle
///   connections are never reaped below that floor.
///
/// Close failures are logged, not propagated.
///
/// NOTE(review): this loop never returns and holds a strong `Arc` to the pool, so
/// the pool (and this task) are never dropped. Consider holding a `Weak` and
/// exiting once `upgrade()` fails.
async fn reap_unused_connections(pool: Arc<RabbitMqPool>) {
    let mut interval = interval(Duration::from_secs(10));
    loop {
        interval.tick().await;
        let mut connections = pool.connections.lock().await;
        // Dead connections are useless whatever the pool size, so always drop them.
        let mut removed_entries = drain_filter(&mut connections, |entry| {
            !entry.connection.status().connected()
        });
        // Reap idle connections only down to `min_connections`. `drain_filter`
        // visits entries in order, so the earliest idle entries are the ones removed.
        let mut excess = connections
            .len()
            .saturating_sub(pool.min_connections as usize);
        removed_entries.extend(drain_filter(&mut connections, |entry| {
            if excess > 0 && entry.channels.load(Ordering::Relaxed) == 0 {
                excess -= 1;
                true
            } else {
                false
            }
        }));
        // Release the lock before the close handshakes so `create_channel` is not blocked.
        drop(connections);
        for entry in removed_entries {
            if entry.connection.status().connected() {
                if let Err(e) = entry.connection.close(0, "closing").await {
                    log::error!("Failed to close connection in gc {}", e);
                }
            }
        }
    }
}
/// Removes every element of `vec` for which `predicate` returns `true` and
/// returns the removed elements.
///
/// Relative order is preserved both in `vec` and in the returned `Vec`.
/// `predicate` is called exactly once per element, in the original order, so
/// stateful predicates (e.g. "remove at most N matches") behave predictably.
/// Runs in O(n) time.
fn drain_filter<T>(vec: &mut Vec<T>, mut predicate: impl FnMut(&T) -> bool) -> Vec<T> {
    let (removed, kept): (Vec<T>, Vec<T>) = vec.drain(..).partition(|item| predicate(item));
    *vec = kept;
    removed
}
#[cfg(test)]
mod test {
    use pretty_assertions::assert_eq;

    /// Removing a contiguous tail; the predicate runs exactly once per element.
    #[test]
    fn test_drain_filter() {
        let mut items = vec![0, 1, 2, 3, 4, 5];
        let mut iterations = 0;
        let removed = super::drain_filter(&mut items, |i| {
            iterations += 1;
            *i > 2
        });
        assert_eq!(iterations, 6);
        assert_eq!(items, vec![0, 1, 2]);
        assert_eq!(removed, vec![3, 4, 5]);
    }

    /// Matches interleaved with non-matches keep their relative order on both sides.
    #[test]
    fn test_drain_filter_interleaved() {
        let mut items = vec![3, 0, 4, 1, 5, 2];
        let removed = super::drain_filter(&mut items, |i| *i > 2);
        assert_eq!(items, vec![0, 1, 2]);
        assert_eq!(removed, vec![3, 4, 5]);
    }

    /// Empty input, everything matching, and nothing matching.
    #[test]
    fn test_drain_filter_edge_cases() {
        let mut empty: Vec<i32> = Vec::new();
        let removed = super::drain_filter(&mut empty, |_| true);
        assert_eq!(removed, Vec::<i32>::new());
        assert_eq!(empty, Vec::<i32>::new());

        let mut all = vec![1, 2, 3];
        let removed = super::drain_filter(&mut all, |_| true);
        assert_eq!(removed, vec![1, 2, 3]);
        assert_eq!(all, Vec::<i32>::new());

        let mut none = vec![1, 2, 3];
        let removed = super::drain_filter(&mut none, |_| false);
        assert_eq!(removed, Vec::<i32>::new());
        assert_eq!(none, vec![1, 2, 3]);
    }

    /// The predicate sees elements in their original order, which stateful
    /// predicates (such as the reaper's quota) rely on.
    #[test]
    fn test_drain_filter_visits_in_order() {
        let mut items = vec![10, 20, 30];
        let mut seen = Vec::new();
        let removed = super::drain_filter(&mut items, |i| {
            seen.push(*i);
            false
        });
        assert_eq!(seen, vec![10, 20, 30]);
        assert_eq!(removed, Vec::<i32>::new());
        assert_eq!(items, vec![10, 20, 30]);
    }
}