use std::{
collections::HashMap,
pin::Pin,
sync::{
atomic::{
AtomicBool,
Ordering::{Relaxed, SeqCst},
},
Arc,
},
task::{Context, Poll},
};
use rabbitmq_stream_protocol::{
commands::subscribe::OffsetSpecification, message::Message, ResponseKind,
};
use tokio::sync::mpsc::{channel, Receiver, Sender};
use tracing::trace;
use crate::error::ConsumerStoreOffsetError;
use crate::{
client::{MessageHandler, MessageResult},
error::{ConsumerCloseError, ConsumerCreateError, ConsumerDeliveryError},
Client, ClientOptions, Environment, MetricsCollector,
};
use futures::{task::AtomicWaker, Stream};
use rand::rngs::StdRng;
use rand::{seq::SliceRandom, SeedableRng};
/// A subscription to one stream, consumed as a [`Stream`] of deliveries.
///
/// Deliveries are pushed by the connection's message handler into an internal
/// channel and read back through this type's `Stream` implementation.
pub struct Consumer {
    /// Optional consumer name; `store_offset` / `query_offset` fail without it.
    name: Option<String>,
    /// Receiving half of the delivery channel fed by the message handler.
    receiver: Receiver<Result<Delivery, ConsumerDeliveryError>>,
    /// State shared with the message handler and any `ConsumerHandle`s.
    internal: Arc<ConsumerInternal>,
}
/// State shared between a [`Consumer`], its [`ConsumerHandle`]s and the
/// connection's message handler.
struct ConsumerInternal {
    /// Connection the subscription was created on.
    client: Client,
    /// Name of the subscribed stream.
    stream: String,
    /// Subscription id used for the subscribe, credit and unsubscribe calls.
    subscription_id: u8,
    /// Sending half of the delivery channel; used by the message handler.
    sender: Sender<Result<Delivery, ConsumerDeliveryError>>,
    /// Set once the consumer is closed, either by `ConsumerHandle::close`
    /// or when the message handler is given `None`.
    closed: Arc<AtomicBool>,
    /// Waker of the task polling the `Consumer` stream; woken when closing.
    waker: AtomicWaker,
    /// Collector told how many messages each delivery contained.
    metrics_collector: Arc<dyn MetricsCollector>,
}
impl ConsumerInternal {
fn is_closed(&self) -> bool {
self.closed.load(Relaxed)
}
}
/// Builder for a [`Consumer`]; finish with [`ConsumerBuilder::build`].
pub struct ConsumerBuilder {
    /// Name set via `name`; needed later for offset storage and queries.
    pub(crate) consumer_name: Option<String>,
    /// Environment used to create the initial client connection.
    pub(crate) environment: Environment,
    /// Offset specification passed to `subscribe`; set via `offset`.
    pub(crate) offset_specification: OffsetSpecification,
}
impl ConsumerBuilder {
pub async fn build(self, stream: &str) -> Result<Consumer, ConsumerCreateError> {
let mut client = self.environment.create_client().await?;
let collector = self.environment.options.client_options.collector.clone();
if let Some(metadata) = client.metadata(vec![stream.to_string()]).await?.get(stream) {
if let Some(replica) = metadata.replicas.choose(&mut StdRng::from_entropy()) {
tracing::debug!(
"Picked replica {:?} out of possible candidates {:?} for stream {}",
replica,
metadata.replicas,
stream
);
client = Client::connect(ClientOptions {
host: replica.host.clone(),
port: replica.port as u16,
..self.environment.options.client_options
})
.await?;
}
} else {
return Err(ConsumerCreateError::StreamDoesNotExist {
stream: stream.into(),
});
}
let subscription_id = 1;
let (tx, rx) = channel(10000);
let consumer = Arc::new(ConsumerInternal {
subscription_id,
stream: stream.to_string(),
client: client.clone(),
sender: tx,
closed: Arc::new(AtomicBool::new(false)),
waker: AtomicWaker::new(),
metrics_collector: collector,
});
let msg_handler = ConsumerMessageHandler(consumer.clone());
client.set_handler(msg_handler).await;
let response = client
.subscribe(
subscription_id,
stream,
self.offset_specification,
1,
HashMap::new(),
)
.await?;
if response.is_ok() {
Ok(Consumer {
name: self.consumer_name,
receiver: rx,
internal: consumer,
})
} else {
Err(ConsumerCreateError::Create {
stream: stream.to_owned(),
status: response.code().clone(),
})
}
}
pub fn offset(mut self, offset_specification: OffsetSpecification) -> Self {
self.offset_specification = offset_specification;
self
}
pub fn name(mut self, consumer_name: &str) -> Self {
self.consumer_name = Some(String::from(consumer_name));
self
}
}
impl Consumer {
    /// Returns a [`ConsumerHandle`] sharing this consumer's state, which can
    /// be used to close the consumer.
    pub fn handle(&self) -> ConsumerHandle {
        ConsumerHandle(self.internal.clone())
    }

    /// Returns `true` once the consumer has been marked closed.
    pub fn is_closed(&self) -> bool {
        self.internal.is_closed()
    }

    /// Stores `offset` on the broker under this consumer's name for its stream.
    ///
    /// # Errors
    ///
    /// * [`ConsumerStoreOffsetError::NameMissing`] if the consumer has no name.
    /// * Client errors from the store call, converted through `?`.
    pub async fn store_offset(&self, offset: u64) -> Result<(), ConsumerStoreOffsetError> {
        let name = self
            .name
            .as_deref()
            .ok_or(ConsumerStoreOffsetError::NameMissing)?;
        Ok(self
            .internal
            .client
            .store_offset(name, self.internal.stream.as_str(), offset)
            .await?)
    }

    /// Queries the offset stored on the broker under this consumer's name for
    /// its stream.
    ///
    /// # Errors
    ///
    /// * [`ConsumerStoreOffsetError::NameMissing`] if the consumer has no name.
    /// * Client errors from the query call, converted through `?`.
    pub async fn query_offset(&self) -> Result<u64, ConsumerStoreOffsetError> {
        let name = self
            .name
            .as_deref()
            .ok_or(ConsumerStoreOffsetError::NameMissing)?;
        Ok(self
            .internal
            .client
            .query_offset(name.to_owned(), self.internal.stream.as_str())
            .await?)
    }
}
impl Stream for Consumer {
    type Item = Result<Delivery, ConsumerDeliveryError>;

    /// Yields the next delivery from the internal channel.
    ///
    /// Anything the channel has ready is still yielded after the consumer is
    /// closed. The stream ends with `None` only when nothing is ready and the
    /// consumer is closed.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // Register before checking the closed flag so that a wake issued by a
        // later close reaches this task.
        self.internal.waker.register(cx.waker());
        let polled = Pin::new(&mut self.receiver).poll_recv(cx);
        if polled.is_pending() && self.is_closed() {
            Poll::Ready(None)
        } else {
            polled
        }
    }
}
/// Handle obtained from [`Consumer::handle`]; shares the consumer's state and
/// can close it.
pub struct ConsumerHandle(Arc<ConsumerInternal>);
impl ConsumerHandle {
    /// Closes the consumer by unsubscribing it.
    ///
    /// The closed flag is flipped atomically first, so only one call gets past
    /// that point. After a successful unsubscribe, the task polling the
    /// `Consumer` is woken so its stream can end.
    ///
    /// NOTE(review): if unsubscribing fails, the closed flag stays set and the
    /// waker is not woken. Confirm this is intended.
    ///
    /// # Errors
    ///
    /// * [`ConsumerCloseError::AlreadyClosed`] if the flag was already set.
    /// * [`ConsumerCloseError::Close`] if the unsubscribe response is not OK.
    /// * Client errors from the unsubscribe call, converted through `?`.
    pub async fn close(self) -> Result<(), ConsumerCloseError> {
        let internal = &self.0;
        let already_closed = internal
            .closed
            .compare_exchange(false, true, SeqCst, SeqCst)
            .is_err();
        if already_closed {
            return Err(ConsumerCloseError::AlreadyClosed);
        }

        let response = internal.client.unsubscribe(internal.subscription_id).await?;
        if !response.is_ok() {
            return Err(ConsumerCloseError::Close {
                stream: internal.stream.clone(),
                status: response.code().clone(),
            });
        }
        internal.waker.wake();
        Ok(())
    }

    /// Returns `true` once the consumer has been marked closed.
    pub async fn is_closed(&self) -> bool {
        self.0.is_closed()
    }
}
struct ConsumerMessageHandler(Arc<ConsumerInternal>);
#[async_trait::async_trait]
impl MessageHandler for ConsumerMessageHandler {
async fn handle_message(&self, item: MessageResult) -> crate::RabbitMQStreamResult<()> {
match item {
Some(Ok(response)) => {
if let ResponseKind::Deliver(delivery) = response.kind() {
let mut offset = delivery.chunk_first_offset;
let len = delivery.messages.len();
trace!("Got delivery with messages {}", len);
for message in delivery.messages {
let _ = self
.0
.sender
.send(Ok(Delivery {
subscription_id: self.0.subscription_id,
message,
offset,
}))
.await;
offset += 1;
}
let _ = self.0.client.credit(self.0.subscription_id, 1).await;
self.0.metrics_collector.consume(len as u64).await;
}
}
Some(Err(err)) => {
let _ = self.0.sender.send(Err(err.into())).await;
}
None => {
trace!("Closing consumer");
self.0.closed.store(true, Relaxed);
self.0.waker.wake();
}
}
Ok(())
}
}
/// A single message received on a subscription, together with its offset.
#[derive(Debug)]
pub struct Delivery {
    /// Subscription the message was delivered on.
    subscription_id: u8,
    /// The received message.
    message: Message,
    /// The chunk's first offset plus the message's position within the chunk.
    offset: u64,
}
impl Delivery {
pub fn subscription_id(&self) -> u8 {
self.subscription_id
}
pub fn message(&self) -> &Message {
&self.message
}
pub fn offset(&self) -> u64 {
self.offset
}
}