#![cfg_attr(feature = "docs", feature(doc_cfg))]
#![warn(missing_docs)]
#![recursion_limit = "256"]
mod config;
mod conn;
mod error;
mod metrics_utils;
#[cfg(feature = "unstable")]
#[cfg_attr(feature = "docs", doc(cfg(unstable)))]
pub mod runtime;
mod spawn;
mod time;
pub use error::Error;
pub use async_trait::async_trait;
pub use config::Builder;
use config::{Config, InternalConfig, ShareConfig};
use conn::{ActiveConn, ConnState, IdleConn};
use futures_channel::mpsc::{self, Receiver, Sender};
use futures_util::lock::{Mutex, MutexGuard};
use futures_util::select;
use futures_util::FutureExt;
use futures_util::SinkExt;
use futures_util::StreamExt;
use metrics::gauge;
use metrics_utils::DurationHistogramGuard;
pub use spawn::spawn;
use std::fmt;
use std::future::Future;
use std::ops::{Deref, DerefMut};
use std::sync::{
atomic::{AtomicU64, Ordering},
Arc, Weak,
};
use std::time::{Duration, Instant};
#[doc(hidden)]
pub use time::{delay_for, interval};
use tokio::sync::{OwnedSemaphorePermit, Semaphore};
use crate::metrics_utils::{GaugeGuard, IDLE_CONNECTIONS, WAIT_COUNT, WAIT_DURATION};
const CONNECTION_REQUEST_QUEUE_SIZE: usize = 10000;
/// The trait a connection manager must implement to drive the pool.
///
/// The manager is responsible for opening new connections and checking
/// the health of existing ones; the pool calls back into it as needed.
#[async_trait]
pub trait Manager: Send + Sync + 'static {
    /// The raw connection type this manager produces.
    type Connection: Send + 'static;
    /// The error type returned by [`Manager::connect`] and [`Manager::check`].
    type Error: Send + Sync + 'static;

    /// Spawns a background task on behalf of the pool.
    ///
    /// Defaults to the crate-level `spawn`; override to integrate with a
    /// custom runtime.
    fn spawn_task<T>(&self, task: T)
    where
        T: Future + Send + 'static,
        T::Output: Send + 'static,
    {
        spawn(task);
    }

    /// Opens a brand-new connection.
    async fn connect(&self) -> Result<Self::Connection, Self::Error>;

    /// Performs an asynchronous health check, returning the connection on
    /// success so ownership round-trips through the manager.
    async fn check(&self, conn: Self::Connection) -> Result<Self::Connection, Self::Error>;

    /// Quick synchronous validity test run when a connection is returned
    /// to the pool. Defaults to accepting every connection.
    #[inline]
    fn validate(&self, _conn: &mut Self::Connection) -> bool {
        true
    }
}
/// State shared by every `Pool` handle: static configuration, the user's
/// manager, the mutex-guarded mutable internals, lock-free counters, and
/// the semaphore bounding concurrently open connections.
struct SharedPool<M: Manager> {
    config: ShareConfig,
    manager: M,
    // Mutable pool state; always accessed through the async mutex.
    internals: Mutex<PoolInternals<M::Connection>>,
    state: PoolState,
    // One permit per open connection; sized from max_open in new_inner.
    semaphore: Arc<Semaphore>,
}
/// Mutable pool state protected by `SharedPool::internals`.
struct PoolInternals<C> {
    config: InternalConfig,
    // Idle connections available for reuse; used as a LIFO stack
    // (push on check-in, pop on check-out).
    free_conns: Vec<IdleConn<C>>,
    // Cumulative time callers have spent waiting for a connection.
    wait_duration: Duration,
    // Present while the background cleaner task runs; sending wakes the
    // cleaner early, taking it signals the cleaner to stop.
    cleaner_ch: Option<Sender<()>>,
}
/// Lock-free statistics counters, readable without taking the mutex.
struct PoolState {
    // Number of currently open connections (Arc: also updated from
    // ConnState in conn.rs — confirm exact semantics there).
    num_open: Arc<AtomicU64>,
    // Connections closed for exceeding max_lifetime (bumped by the cleaner).
    max_lifetime_closed: AtomicU64,
    // Presumably counts connections dropped due to the idle cap; the
    // increment lives in conn.rs via ConnState — verify there.
    max_idle_closed: Arc<AtomicU64>,
    // Callers currently waiting inside get_or_create_conn.
    wait_count: AtomicU64,
}
impl<C> Drop for PoolInternals<C> {
    // Debug-level trace to help diagnose pool shutdown ordering.
    fn drop(&mut self) {
        log::debug!("Pool internal drop");
    }
}
/// A generic asynchronous connection pool.
///
/// `Pool` is a cheap handle: cloning it shares the same underlying pool
/// state via `Arc`.
pub struct Pool<M: Manager>(Arc<SharedPool<M>>);
impl<M: Manager> Clone for Pool<M> {
fn clone(&self) -> Self {
Pool(self.0.clone())
}
}
impl<M: Manager> fmt::Debug for Pool<M> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // The manager and connections carry no Debug bound, so the pool
        // prints as a fixed tag.
        f.write_str("Pool")
    }
}
/// A point-in-time snapshot of pool statistics, returned by `Pool::state`.
#[derive(Debug)]
pub struct State {
    /// Maximum number of open connections (`0` means unlimited).
    pub max_open: u64,
    /// Number of connections currently open.
    pub connections: u64,
    /// Number of connections currently checked out (open minus idle).
    pub in_use: u64,
    /// Number of idle connections waiting for reuse.
    pub idle: u64,
    /// Number of callers currently waiting for a connection.
    pub wait_count: u64,
    /// Cumulative time callers have spent waiting for a connection.
    pub wait_duration: Duration,
    /// Connections closed due to the idle-connection limit.
    pub max_idle_closed: u64,
    /// Connections closed for exceeding the maximum lifetime.
    pub max_lifetime_closed: u64,
}
impl<M: Manager> Drop for Pool<M> {
    // NOTE(review): intentionally empty — dropping a handle does nothing
    // special today; presumably kept as an explicit extension point for
    // shutdown logic. Confirm before removing.
    fn drop(&mut self) {}
}
impl<M: Manager> Pool<M> {
pub fn new(manager: M) -> Pool<M> {
Pool::builder().build(manager)
}
pub fn builder() -> Builder<M> {
Builder::new()
}
pub async fn set_max_open_conns(&self, n: u64) {
let mut internals = self.0.internals.lock().await;
internals.config.max_open = n;
if n > 0 && internals.config.max_idle > n {
drop(internals);
self.set_max_idle_conns(n).await;
}
}
pub async fn set_max_idle_conns(&self, n: u64) {
let mut internals = self.0.internals.lock().await;
internals.config.max_idle =
if internals.config.max_open > 0 && n > internals.config.max_open {
internals.config.max_open
} else {
n
};
let max_idle = internals.config.max_idle as usize;
if max_idle > 0 && internals.free_conns.len() > max_idle {
internals.free_conns.truncate(max_idle);
}
}
pub async fn set_conn_max_lifetime(&self, max_lifetime: Option<Duration>) {
assert_ne!(
max_lifetime,
Some(Duration::from_secs(0)),
"max_lifetime must be positive"
);
let mut internals = self.0.internals.lock().await;
internals.config.max_lifetime = max_lifetime;
if let Some(lifetime) = max_lifetime {
match internals.config.max_lifetime {
Some(prev) if lifetime < prev && internals.cleaner_ch.is_some() => {
let _ = internals.cleaner_ch.as_mut().unwrap().send(()).await;
}
_ => (),
}
}
if max_lifetime.is_some()
&& self.0.state.num_open.load(Ordering::Relaxed) > 0
&& internals.cleaner_ch.is_none()
{
log::debug!("run connection cleaner");
let shared1 = Arc::downgrade(&self.0);
let clean_rate = self.0.config.clean_rate;
let (cleaner_ch_sender, cleaner_ch) = mpsc::channel(1);
internals.cleaner_ch = Some(cleaner_ch_sender);
self.0.manager.spawn_task(async move {
connection_cleaner(shared1, cleaner_ch, clean_rate).await;
});
}
}
pub(crate) fn new_inner(manager: M, config: Config) -> Self {
let max_open = if config.max_open == 0 {
CONNECTION_REQUEST_QUEUE_SIZE
} else {
config.max_open as usize
};
gauge!(IDLE_CONNECTIONS).set(0.0);
let (share_config, internal_config) = config.split();
let internals = Mutex::new(PoolInternals {
config: internal_config,
free_conns: Vec::new(),
wait_duration: Duration::from_secs(0),
cleaner_ch: None,
});
let pool_state = PoolState {
num_open: Arc::new(AtomicU64::new(0)),
max_lifetime_closed: AtomicU64::new(0),
wait_count: AtomicU64::new(0),
max_idle_closed: Arc::new(AtomicU64::new(0)),
};
let shared = Arc::new(SharedPool {
config: share_config,
manager,
internals,
semaphore: Arc::new(Semaphore::new(max_open)),
state: pool_state,
});
Pool(shared)
}
pub async fn get(&self) -> Result<Connection<M>, Error<M::Error>> {
match self.0.config.get_timeout {
Some(duration) => self.get_timeout(duration).await,
None => self.inner_get_with_retries().await,
}
}
pub async fn get_timeout(&self, duration: Duration) -> Result<Connection<M>, Error<M::Error>> {
time::timeout(duration, self.inner_get_with_retries()).await
}
async fn inner_get_with_retries(&self) -> Result<Connection<M>, Error<M::Error>> {
let mut try_times: u32 = 0;
let config = &self.0.config;
loop {
try_times += 1;
match self.get_connection().await {
Ok(conn) => return Ok(conn),
Err(Error::BadConn) => {
if try_times == config.max_bad_conn_retries {
return self.get_connection().await;
}
continue;
}
Err(err) => return Err(err),
}
}
}
async fn get_connection(&self) -> Result<Connection<M>, Error<M::Error>> {
let _guard = GaugeGuard::increment(WAIT_COUNT);
let c = self.get_or_create_conn().await?;
let conn = Connection {
pool: self.clone(),
conn: Some(c),
};
Ok(conn)
}
async fn validate_conn(
&self,
internal_config: InternalConfig,
conn: IdleConn<M::Connection>,
) -> Option<IdleConn<M::Connection>> {
if conn.is_brand_new() {
return Some(conn);
}
if conn.expired(internal_config.max_lifetime) {
return None;
}
if conn.idle_expired(internal_config.max_idle_lifetime) {
return None;
}
let needs_health_check = self.0.config.health_check
&& conn.needs_health_check(self.0.config.health_check_interval);
if needs_health_check {
let (raw, split) = conn.split_raw();
let checked_raw = self.0.manager.check(raw).await.ok()?;
let mut checked = split.restore(checked_raw);
checked.mark_checked();
return Some(checked);
}
Some(conn)
}
async fn get_or_create_conn(&self) -> Result<ActiveConn<M::Connection>, Error<M::Error>> {
self.0.state.wait_count.fetch_add(1, Ordering::Relaxed);
let wait_guard = DurationHistogramGuard::start(WAIT_DURATION);
let semaphore = Arc::clone(&self.0.semaphore);
let permit = semaphore
.acquire_owned()
.await
.map_err(|_| Error::PoolClosed)?;
self.0.state.wait_count.fetch_sub(1, Ordering::SeqCst);
let mut internals = self.0.internals.lock().await;
internals.wait_duration += wait_guard.into_elapsed();
let conn = internals.free_conns.pop();
let internal_config = internals.config.clone();
drop(internals);
if let Some(conn) = conn {
if let Some(valid_conn) = self.validate_conn(internal_config, conn).await {
return Ok(valid_conn.into_active(permit));
}
}
let create_r = self.open_new_connection(permit).await;
create_r
}
async fn open_new_connection(
&self,
permit: OwnedSemaphorePermit,
) -> Result<ActiveConn<M::Connection>, Error<M::Error>> {
log::debug!("creating new connection from manager");
match self.0.manager.connect().await {
Ok(c) => {
self.0.state.num_open.fetch_add(1, Ordering::Relaxed);
let state = ConnState::new(
Arc::clone(&self.0.state.num_open),
Arc::clone(&self.0.state.max_idle_closed),
);
Ok(ActiveConn::new(c, permit, state))
}
Err(e) => Err(Error::Inner(e)),
}
}
pub async fn state(&self) -> State {
let internals = self.0.internals.lock().await;
let num_free_conns = internals.free_conns.len() as u64;
let wait_duration = internals.wait_duration;
let max_open = internals.config.max_open;
drop(internals);
State {
max_open,
connections: self.0.state.num_open.load(Ordering::Relaxed),
in_use: self.0.state.num_open.load(Ordering::Relaxed) - num_free_conns,
idle: num_free_conns,
wait_count: self.0.state.wait_count.load(Ordering::Relaxed),
wait_duration,
max_idle_closed: self.0.state.max_idle_closed.load(Ordering::Relaxed),
max_lifetime_closed: self.0.state.max_lifetime_closed.load(Ordering::Relaxed),
}
}
}
/// Returns a checked-out connection to the idle queue, discarding it if
/// the manager no longer considers it valid.
async fn recycle_conn<M: Manager>(
    shared: &Arc<SharedPool<M>>,
    mut conn: ActiveConn<M::Connection>,
) {
    // Invalid connections are simply dropped here.
    if !conn_still_valid(shared, &mut conn) {
        return;
    }
    conn.set_brand_new(false);
    let guard = shared.internals.lock().await;
    put_idle_conn::<M>(guard, conn);
}
/// Runs the manager's quick synchronous validity check on a connection
/// being returned to the pool.
fn conn_still_valid<M: Manager>(
    shared: &Arc<SharedPool<M>>,
    conn: &mut ActiveConn<M::Connection>,
) -> bool {
    let valid = shared.manager.validate(conn.as_raw_mut());
    if !valid {
        log::debug!("bad conn when check in");
    }
    valid
}
/// Pushes a connection onto the idle queue unless the idle cap is reached,
/// in which case the connection is dropped.
fn put_idle_conn<M: Manager>(
    mut internals: MutexGuard<'_, PoolInternals<M::Connection>>,
    conn: ActiveConn<M::Connection>,
) {
    let idle_conn = conn.into_idle();
    // A cap of zero means "no limit on idle connections".
    let cap = internals.config.max_idle;
    let under_cap = cap == 0 || (internals.free_conns.len() as u64) < cap;
    if under_cap {
        internals.free_conns.push(idle_conn);
    }
}
/// Background task that periodically prunes expired idle connections.
///
/// Wakes either on the `clean_rate` interval or when a message arrives on
/// `cleaner_ch` (sent when `max_lifetime` is shortened). Exits when the
/// channel closes or when `clean_connection` reports there is nothing
/// left to manage (pool dropped, no open connections, or limit removed).
async fn connection_cleaner<M: Manager>(
    shared: Weak<SharedPool<M>>,
    mut cleaner_ch: Receiver<()>,
    clean_rate: Duration,
) {
    let mut interval = interval(clean_rate);
    // Consume the initial tick so the first cleaning pass happens one
    // full period from now (tokio-style intervals fire immediately on the
    // first tick — confirm the crate's `time::interval` wrapper matches).
    interval.tick().await;
    loop {
        select! {
            // Either source waking us triggers one cleaning pass.
            _ = interval.tick().fuse() => (),
            r = cleaner_ch.next().fuse() => match r {
                Some(()) => (),
                // Channel closed: the sender side was dropped; stop.
                None => return
            },
        }
        if !clean_connection(&shared).await {
            return;
        }
    }
}
/// Runs one cleaning pass, removing idle connections older than
/// `max_lifetime`. Returns `false` when the cleaner should stop.
async fn clean_connection<M: Manager>(shared: &Weak<SharedPool<M>>) -> bool {
    // The pool itself holds only weakly; if every handle is gone, stop.
    let shared = match shared.upgrade() {
        Some(shared) => shared,
        None => {
            log::debug!("Failed to clean connections");
            return false;
        }
    };
    log::debug!("Clean connections");
    let mut internals = shared.internals.lock().await;
    if shared.state.num_open.load(Ordering::Relaxed) == 0 || internals.config.max_lifetime.is_none()
    {
        // Nothing left to manage: drop the wake-up sender so a future
        // set_conn_max_lifetime can start a fresh cleaner.
        internals.cleaner_ch.take();
        return false;
    }
    // Any connection created before this instant has outlived max_lifetime.
    let expired = Instant::now() - internals.config.max_lifetime.unwrap();
    let mut closing = vec![];
    let mut i = 0;
    log::debug!(
        "clean connections, idle conns {}",
        internals.free_conns.len()
    );
    // Extract expired connections via swap_remove. `i` is only advanced
    // when the current slot is kept, because swap_remove moves the last
    // element into position `i`, which must then itself be examined.
    loop {
        if i >= internals.free_conns.len() {
            break;
        }
        if internals.free_conns[i].created_at() < expired {
            let c = internals.free_conns.swap_remove(i);
            closing.push(c);
            continue;
        }
        i += 1;
    }
    // Release the lock first; `closing` (and its connections) drops at
    // end of scope, outside the critical section.
    drop(internals);
    shared
        .state
        .max_lifetime_closed
        .fetch_add(closing.len() as u64, Ordering::Relaxed);
    true
}
/// A smart-pointer wrapper around a pooled connection.
///
/// Dereferences to the raw connection, and returns the connection to the
/// pool when dropped.
pub struct Connection<M: Manager> {
    pool: Pool<M>,
    // `None` only after `into_inner` has taken the connection out.
    conn: Option<ActiveConn<M::Connection>>,
}
impl<M: Manager> Connection<M> {
    /// Returns `true` if this connection was freshly created rather than
    /// reused from the idle queue.
    pub fn is_brand_new(&self) -> bool {
        let active = self.conn.as_ref().unwrap();
        active.is_brand_new()
    }

    /// Consumes the wrapper and extracts the raw connection, so it will
    /// not be returned to the pool on drop.
    pub fn into_inner(mut self) -> M::Connection {
        let active = self.conn.take().unwrap();
        active.into_raw()
    }
}
impl<M: Manager> Drop for Connection<M> {
    /// Hands the connection back to the pool asynchronously on drop.
    fn drop(&mut self) {
        let conn = match self.conn.take() {
            Some(conn) => conn,
            // Already taken via `into_inner`; nothing to recycle.
            None => return,
        };
        let pool = Arc::clone(&self.pool.0);
        self.pool.0.manager.spawn_task(async move {
            recycle_conn(&pool, conn).await;
        });
    }
}
impl<M: Manager> Deref for Connection<M> {
    type Target = M::Connection;

    fn deref(&self) -> &Self::Target {
        let active = self.conn.as_ref().unwrap();
        active.as_raw_ref()
    }
}
impl<M: Manager> DerefMut for Connection<M> {
    fn deref_mut(&mut self) -> &mut M::Connection {
        let active = self.conn.as_mut().unwrap();
        active.as_raw_mut()
    }
}