1use std::ffi::CString;
4use std::marker::PhantomData;
5use std::os::raw::c_void;
6use std::pin::Pin;
7use std::ptr;
8use std::sync::{Arc, Mutex};
9use std::task::{Context, Poll, Waker};
10use std::time::Duration;
11
12use crate::log::trace;
13use futures_channel::oneshot;
14use futures_util::future::{self, Either, FutureExt};
15use futures_util::pin_mut;
16use futures_util::stream::{Stream, StreamExt};
17use slab::Slab;
18
19use rdkafka_sys as rdsys;
20use rdkafka_sys::types::*;
21
22use crate::client::{Client, NativeQueue};
23use crate::config::{ClientConfig, FromClientConfig, FromClientConfigAndContext};
24use crate::consumer::base_consumer::BaseConsumer;
25use crate::consumer::{
26 CommitMode, Consumer, ConsumerContext, ConsumerGroupMetadata, DefaultConsumerContext,
27 RebalanceProtocol,
28};
29use crate::error::{KafkaError, KafkaResult};
30use crate::groups::GroupList;
31use crate::message::BorrowedMessage;
32use crate::metadata::Metadata;
33use crate::topic_partition_list::{Offset, TopicPartitionList};
34use crate::util::{AsyncRuntime, DefaultRuntime, NativePtr, Timeout};
35
36unsafe extern "C" fn native_message_queue_nonempty_cb(_: *mut RDKafka, opaque_ptr: *mut c_void) {
37 let wakers = &*(opaque_ptr as *const WakerSlab);
38 wakers.wake_all();
39}
40
41unsafe fn enable_nonempty_callback(queue: &NativeQueue, wakers: &Arc<WakerSlab>) {
42 rdsys::rd_kafka_queue_cb_event_enable(
43 queue.ptr(),
44 Some(native_message_queue_nonempty_cb),
45 Arc::as_ptr(wakers) as *mut c_void,
46 )
47}
48
49unsafe fn disable_nonempty_callback(queue: &NativeQueue) {
50 rdsys::rd_kafka_queue_cb_event_enable(queue.ptr(), None, ptr::null_mut())
51}
52
/// Registry of task wakers waiting for messages on one native queue.
///
/// Each `MessageStream` reserves a slot with `register`, parks its waker there with
/// `set_waker` while pending, and releases the slot with `unregister` on drop. `wake_all`
/// wakes and clears every parked waker. The mutex guards access from both async code and the
/// librdkafka queue callback (`native_message_queue_nonempty_cb`).
struct WakerSlab {
    // Slot index -> parked waker; `None` while the owning stream is not waiting.
    wakers: Mutex<Slab<Option<Waker>>>,
}
56
57impl WakerSlab {
58 fn new() -> WakerSlab {
59 WakerSlab {
60 wakers: Mutex::new(Slab::new()),
61 }
62 }
63
64 fn wake_all(&self) {
65 let mut wakers = self.wakers.lock().unwrap();
66 for (_, waker) in wakers.iter_mut() {
67 if let Some(waker) = waker.take() {
68 waker.wake();
69 }
70 }
71 }
72
73 fn register(&self) -> usize {
74 let mut wakers = self.wakers.lock().expect("lock poisoned");
75 wakers.insert(None)
76 }
77
78 fn unregister(&self, slot: usize) {
79 let mut wakers = self.wakers.lock().expect("lock poisoned");
80 wakers.remove(slot);
81 }
82
83 fn set_waker(&self, slot: usize, waker: Waker) {
84 let mut wakers = self.wakers.lock().expect("lock poisoned");
85 wakers[slot] = Some(waker);
86 }
87}
88
/// A [`Stream`] of messages read from a single native librdkafka queue.
///
/// Obtained from `StreamConsumer::stream` or `StreamPartitionQueue::stream`. Each stream owns
/// one waker slot in the queue's `WakerSlab`; the slot is released when the stream is dropped.
/// Yielded messages borrow from the queue for `'a`.
pub struct MessageStream<'a> {
    // Waker registry associated with `queue`.
    wakers: &'a WakerSlab,
    // Native queue messages are consumed from.
    queue: &'a NativeQueue,
    // This stream's slot index in `wakers`.
    slot: usize,
}
97
98impl<'a> MessageStream<'a> {
99 fn new(wakers: &'a WakerSlab, queue: &'a NativeQueue) -> MessageStream<'a> {
100 let slot = wakers.register();
101 MessageStream {
102 wakers,
103 queue,
104 slot,
105 }
106 }
107
108 fn poll(&self) -> Option<KafkaResult<BorrowedMessage<'a>>> {
109 unsafe {
110 NativePtr::from_ptr(rdsys::rd_kafka_consume_queue(self.queue.ptr(), 0))
111 .map(|p| BorrowedMessage::from_consumer(p, self.queue))
112 }
113 }
114}
115
116impl<'a> Stream for MessageStream<'a> {
117 type Item = KafkaResult<BorrowedMessage<'a>>;
118
119 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
120 if let Some(message) = self.poll() {
123 return Poll::Ready(Some(message));
124 }
125
126 self.wakers.set_waker(self.slot, cx.waker().clone());
131
132 match self.poll() {
137 None => Poll::Pending,
138 Some(message) => Poll::Ready(Some(message)),
139 }
140 }
141}
142
143impl<'a> Drop for MessageStream<'a> {
144 fn drop(&mut self) {
145 self.wakers.unregister(self.slot);
146 }
147}
148
/// A Kafka consumer that delivers messages asynchronously through [`MessageStream`]s.
///
/// Built via `FromClientConfig` / `FromClientConfigAndContext`. Construction spawns a
/// wake-up task on runtime `R`. That task exits once `_shutdown_trigger` is dropped, which
/// happens when this consumer is dropped.
#[must_use = "Consumer polling thread will stop immediately if unused"]
pub struct StreamConsumer<C = DefaultConsumerContext, R = DefaultRuntime>
where
    C: ConsumerContext,
{
    // Fields are dropped in declaration order, and that order matters here:
    // - `queue` (a handle on the consumer queue) is released before `base`.
    // - `base` is released before `wakers`.
    // The consumer queue callback holds a raw pointer to `wakers` and is not disabled anywhere
    // in this type, so the slab must outlive the native client.
    // NOTE(review): upstream also relies on `queue` dropping before `base`; confirm before
    // reordering.
    queue: NativeQueue,
    base: BaseConsumer<C>,
    // Waker registry for streams over `queue`; also cloned into the wake-up task.
    wakers: Arc<WakerSlab>,
    // Dropping this sender resolves the wake-up task's tripwire and ends the task.
    _shutdown_trigger: oneshot::Sender<()>,
    // Carries the runtime type `R` without storing a value.
    _runtime: PhantomData<R>,
}
174
175#[async_trait::async_trait]
176impl<R> FromClientConfig for StreamConsumer<DefaultConsumerContext, R>
177where
178 R: AsyncRuntime,
179{
180 async fn from_config(config: &ClientConfig) -> KafkaResult<Self> {
181 StreamConsumer::from_config_and_context(config, DefaultConsumerContext).await
182 }
183}
184
/// Creates a [`StreamConsumer`] with a caller-supplied consumer context.
#[async_trait::async_trait]
impl<C, R> FromClientConfigAndContext<C> for StreamConsumer<C, R>
where
    C: ConsumerContext + 'static,
    R: AsyncRuntime,
{
    /// Builds the base consumer, hooks the consumer queue's non-empty callback to a fresh
    /// `WakerSlab`, and spawns a task on `R` that wakes all streams every
    /// `max.poll.interval.ms / 2`.
    ///
    /// # Errors
    ///
    /// Propagates errors from native config creation, from reading `max.poll.interval.ms`,
    /// and from `BaseConsumer::new`. Returns `KafkaError::ClientCreation` if librdkafka
    /// provides no consumer queue.
    ///
    /// # Panics
    ///
    /// Panics if `max.poll.interval.ms` does not parse as a `u64`.
    async fn from_config_and_context(config: &ClientConfig, context: C) -> KafkaResult<Self> {
        let native_config = config.create_native_config()?;
        // Half of this interval is the period of the wake-up loop below.
        let poll_interval = {
            let millis: u64 = native_config
                .get("max.poll.interval.ms")?
                .parse()
                .expect("librdkafka validated config value is valid u64");
            Duration::from_millis(millis)
        };

        let base = BaseConsumer::new(config, native_config, context)?;
        // Used only by the `trace!` calls in the spawned task. Raw pointers are not `Send`, so
        // presumably it is carried across as an integer for that reason.
        let native_ptr = base.client().native_ptr() as usize;

        // Per the librdkafka API, this redirects the main event queue to the consumer queue,
        // so only `queue` needs to be polled. Its returned error code is discarded here.
        unsafe { rdsys::rd_kafka_poll_set_consumer(base.client().native_ptr()) };

        let queue = base.client().consumer_queue().ok_or_else(|| {
            KafkaError::ClientCreation("librdkafka failed to create consumer queue".into())
        })?;
        let wakers = Arc::new(WakerSlab::new());
        // `wakers` is stored in the returned consumer, which keeps the slab alive while
        // librdkafka holds its raw pointer.
        unsafe { enable_nonempty_callback(&queue, &wakers) }

        // The sender becomes `_shutdown_trigger`. When the consumer is dropped, the receiver
        // resolves and the loop below exits.
        let (shutdown_trigger, shutdown_tripwire) = oneshot::channel();
        let mut shutdown_tripwire = shutdown_tripwire.fuse();
        R::spawn({
            let wakers = wakers.clone();
            async move {
                trace!("Starting stream consumer wake loop: 0x{:x}", native_ptr);
                // Every `poll_interval / 2`, wake all parked streams so they re-poll the queue
                // even if no non-empty callback fired; stop on shutdown.
                loop {
                    let delay = R::delay_for(poll_interval / 2).fuse();
                    pin_mut!(delay);
                    match future::select(&mut delay, &mut shutdown_tripwire).await {
                        Either::Left(_) => wakers.wake_all(),
                        Either::Right(_) => break,
                    }
                }
                trace!("Shut down stream consumer wake loop: 0x{:x}", native_ptr);
            }
        });

        Ok(StreamConsumer {
            base,
            wakers,
            queue,
            _shutdown_trigger: shutdown_trigger,
            _runtime: PhantomData,
        })
    }
}
252
253impl<C, R> StreamConsumer<C, R>
254where
255 C: ConsumerContext + 'static,
256{
257 pub fn stream(&self) -> MessageStream<'_> {
270 MessageStream::new(&self.wakers, &self.queue)
271 }
272
273 pub async fn recv(&self) -> Result<BorrowedMessage<'_>, KafkaError> {
294 self.stream()
295 .next()
296 .await
297 .expect("kafka streams never terminate")
298 }
299
300 pub fn split_partition_queue(
335 self: &Arc<Self>,
336 topic: &str,
337 partition: i32,
338 ) -> Option<StreamPartitionQueue<C, R>> {
339 let topic = match CString::new(topic) {
340 Ok(topic) => topic,
341 Err(_) => return None,
342 };
343 let queue = unsafe {
344 NativeQueue::from_ptr(rdsys::rd_kafka_queue_get_partition(
345 self.base.client().native_ptr(),
346 topic.as_ptr(),
347 partition,
348 ))
349 };
350 queue.map(|queue| {
351 let wakers = Arc::new(WakerSlab::new());
352 unsafe {
353 rdsys::rd_kafka_queue_forward(queue.ptr(), ptr::null_mut());
354 enable_nonempty_callback(&queue, &wakers);
355 }
356 StreamPartitionQueue {
357 queue,
358 wakers,
359 _consumer: self.clone(),
360 }
361 })
362 }
363}
364
/// [`Consumer`] implementation for [`StreamConsumer`]: every method delegates to the wrapped
/// `BaseConsumer` (`self.base`).
#[async_trait::async_trait]
impl<C, R> Consumer<C> for StreamConsumer<C, R>
where
    C: ConsumerContext,
    R: AsyncRuntime,
{
    /// Returns the base consumer's client.
    fn client(&self) -> &Client<C> {
        self.base.client()
    }

    /// Delegates to the base consumer.
    fn group_metadata(&self) -> Option<ConsumerGroupMetadata> {
        self.base.group_metadata()
    }

    /// Delegates to the base consumer.
    fn subscribe(&self, topics: &[&str]) -> KafkaResult<()> {
        self.base.subscribe(topics)
    }

    /// Delegates to the base consumer.
    fn unsubscribe(&self) {
        self.base.unsubscribe();
    }

    /// Delegates to the base consumer.
    fn assign(&self, assignment: &TopicPartitionList) -> KafkaResult<()> {
        self.base.assign(assignment)
    }

    /// Delegates to the base consumer.
    fn unassign(&self) -> KafkaResult<()> {
        self.base.unassign()
    }

    /// Delegates to the base consumer.
    fn incremental_assign(&self, assignment: &TopicPartitionList) -> KafkaResult<()> {
        self.base.incremental_assign(assignment)
    }

    /// Delegates to the base consumer.
    fn incremental_unassign(&self, assignment: &TopicPartitionList) -> KafkaResult<()> {
        self.base.incremental_unassign(assignment)
    }

    /// Delegates to the base consumer.
    fn assignment_lost(&self) -> bool {
        self.base.assignment_lost()
    }

    /// Delegates to the base consumer.
    async fn seek<T: Into<Timeout> + Send>(
        &self,
        topic: &str,
        partition: i32,
        offset: Offset,
        timeout: T,
    ) -> KafkaResult<()> {
        self.base.seek(topic, partition, offset, timeout).await
    }

    /// Delegates to the base consumer.
    async fn seek_partitions<T: Into<Timeout> + Send>(
        &self,
        topic_partition_list: TopicPartitionList,
        timeout: T,
    ) -> KafkaResult<TopicPartitionList> {
        self.base
            .seek_partitions(topic_partition_list, timeout)
            .await
    }

    /// Delegates to the base consumer.
    async fn commit(
        &self,
        topic_partition_list: &TopicPartitionList,
        mode: CommitMode,
    ) -> KafkaResult<()> {
        self.base.commit(topic_partition_list, mode).await
    }

    /// Delegates to the base consumer.
    async fn commit_consumer_state(&self, mode: CommitMode) -> KafkaResult<()> {
        self.base.commit_consumer_state(mode).await
    }

    /// Delegates to the base consumer.
    async fn commit_message(
        &self,
        message: &BorrowedMessage<'_>,
        mode: CommitMode,
    ) -> KafkaResult<()> {
        self.base.commit_message(message, mode).await
    }

    /// Delegates to the base consumer.
    fn store_offset(&self, topic: &str, partition: i32, offset: i64) -> KafkaResult<()> {
        self.base.store_offset(topic, partition, offset)
    }

    /// Delegates to the base consumer.
    fn store_offset_from_message(&self, message: &BorrowedMessage<'_>) -> KafkaResult<()> {
        self.base.store_offset_from_message(message)
    }

    /// Delegates to the base consumer.
    fn store_offsets(&self, tpl: &TopicPartitionList) -> KafkaResult<()> {
        self.base.store_offsets(tpl)
    }

    /// Delegates to the base consumer.
    fn subscription(&self) -> KafkaResult<TopicPartitionList> {
        self.base.subscription()
    }

    /// Delegates to the base consumer.
    fn assignment(&self) -> KafkaResult<TopicPartitionList> {
        self.base.assignment()
    }

    /// Delegates to the base consumer.
    async fn committed<T>(&self, timeout: T) -> KafkaResult<TopicPartitionList>
    where
        T: Into<Timeout> + Send,
        Self: Sized,
    {
        self.base.committed(timeout).await
    }

    /// Delegates to the base consumer.
    async fn committed_offsets<T>(
        &self,
        tpl: TopicPartitionList,
        timeout: T,
    ) -> KafkaResult<TopicPartitionList>
    where
        T: Into<Timeout> + Send,
    {
        self.base.committed_offsets(tpl, timeout).await
    }

    /// Delegates to the base consumer.
    async fn offsets_for_timestamp<T>(
        &self,
        timestamp: i64,
        timeout: T,
    ) -> KafkaResult<TopicPartitionList>
    where
        T: Into<Timeout> + Send,
        Self: Sized,
    {
        self.base.offsets_for_timestamp(timestamp, timeout).await
    }

    /// Delegates to the base consumer.
    async fn offsets_for_times<T>(
        &self,
        timestamps: TopicPartitionList,
        timeout: T,
    ) -> KafkaResult<TopicPartitionList>
    where
        T: Into<Timeout> + Send,
        Self: Sized,
    {
        self.base.offsets_for_times(timestamps, timeout).await
    }

    /// Delegates to the base consumer.
    fn position(&self) -> KafkaResult<TopicPartitionList> {
        self.base.position()
    }

    /// Delegates to the base consumer.
    async fn fetch_metadata<T>(&self, topic: Option<&str>, timeout: T) -> KafkaResult<Metadata>
    where
        T: Into<Timeout> + Send,
        Self: Sized,
    {
        self.base.fetch_metadata(topic, timeout).await
    }

    /// Delegates to the base consumer.
    async fn fetch_watermarks<T>(
        &self,
        topic: &str,
        partition: i32,
        timeout: T,
    ) -> KafkaResult<(i64, i64)>
    where
        T: Into<Timeout> + Send + 'static,
        Self: Sized,
    {
        self.base.fetch_watermarks(topic, partition, timeout).await
    }

    /// Delegates to the base consumer.
    async fn fetch_group_list<T>(&self, group: Option<&str>, timeout: T) -> KafkaResult<GroupList>
    where
        T: Into<Timeout> + Send,
        Self: Sized,
    {
        self.base.fetch_group_list(group, timeout).await
    }

    /// Delegates to the base consumer.
    fn pause(&self, partitions: &TopicPartitionList) -> KafkaResult<()> {
        self.base.pause(partitions)
    }

    /// Delegates to the base consumer.
    fn resume(&self, partitions: &TopicPartitionList) -> KafkaResult<()> {
        self.base.resume(partitions)
    }

    /// Delegates to the base consumer.
    fn rebalance_protocol(&self) -> RebalanceProtocol {
        self.base.rebalance_protocol()
    }
}
555
/// A message source for one partition split off from a [`StreamConsumer`].
///
/// Created by `StreamConsumer::split_partition_queue`. It holds an `Arc` to the parent
/// consumer, so the consumer stays alive at least as long as this queue.
pub struct StreamPartitionQueue<C, R = DefaultRuntime>
where
    C: ConsumerContext,
{
    // Native queue for the split-off partition; forwarding was disabled at creation.
    queue: NativeQueue,
    // Waker registry for streams over `queue`. The queue callback holds a raw pointer to it
    // until `Drop` disables the callback.
    wakers: Arc<WakerSlab>,
    // Keeps the parent consumer alive. Declared last, so it is dropped after `queue` and
    // `wakers`.
    _consumer: Arc<StreamConsumer<C, R>>,
}
568
569impl<C, R> StreamPartitionQueue<C, R>
570where
571 C: ConsumerContext,
572{
573 pub fn stream(&self) -> MessageStream<'_> {
586 MessageStream::new(&self.wakers, &self.queue)
587 }
588
589 pub async fn recv(&self) -> Result<BorrowedMessage<'_>, KafkaError> {
613 self.stream()
614 .next()
615 .await
616 .expect("kafka streams never terminate")
617 }
618}
619
620impl<C, R> Drop for StreamPartitionQueue<C, R>
621where
622 C: ConsumerContext,
623{
624 fn drop(&mut self) {
625 unsafe { disable_nonempty_callback(&self.queue) }
626 }
627}