diff --git a/src/lib.rs b/src/lib.rs index 3df8b59..8fcdc8d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -70,6 +70,9 @@ struct Channel { /// Stream operations while the channel is empty and not closed. stream_ops: Event, + /// Closed operations while the channel is not closed. + closed_ops: Event, + /// The number of currently active `Sender`s. sender_count: AtomicUsize, @@ -89,6 +92,7 @@ impl Channel { // Notify all receive and stream operations. self.recv_ops.notify(usize::MAX); self.stream_ops.notify(usize::MAX); + self.closed_ops.notify(usize::MAX); true } else { @@ -128,6 +132,7 @@ pub fn bounded(cap: usize) -> (Sender, Receiver) { send_ops: Event::new(), recv_ops: Event::new(), stream_ops: Event::new(), + closed_ops: Event::new(), sender_count: AtomicUsize::new(1), receiver_count: AtomicUsize::new(1), }); @@ -169,6 +174,7 @@ pub fn unbounded() -> (Sender, Receiver) { send_ops: Event::new(), recv_ops: Event::new(), stream_ops: Event::new(), + closed_ops: Event::new(), sender_count: AtomicUsize::new(1), receiver_count: AtomicUsize::new(1), }); @@ -258,6 +264,29 @@ impl Sender { }) } + /// Completes when all receiver have dropped. + /// + /// This allows the producers to get notified when interest in the produced values is canceled and immediately stop doing work. + /// + /// # Examples + /// + /// ``` + /// # futures_lite::future::block_on(async { + /// use async_channel::{unbounded, SendError}; + /// + /// let (s, r) = unbounded::(); + /// drop(r); + /// s.closed().await; + /// # }); + /// ``` + pub fn closed(&self) -> Closed<'_, T> { + Closed::_new(ClosedInner { + sender: self, + listener: None, + _pin: PhantomPinned, + }) + } + /// Sends a message into this channel using the blocking strategy. /// /// If the channel is full, this method will block until there is room. @@ -1280,6 +1309,54 @@ impl<'a, T> EventListenerFuture for RecvInner<'a, T> { } } +easy_wrapper! { + /// A future returned by [`Receiver::recv()`]. + #[derive(Debug)] + #[must_use = "futures do nothing unless you `.await` or poll them"] + pub struct Closed<'a, T>(ClosedInner<'a, T> => ()); + #[cfg(all(feature = "std", not(target_family = "wasm")))] + pub(crate) wait(); +} + +pin_project! { + #[derive(Debug)] + #[project(!Unpin)] + struct ClosedInner<'a, T> { + // Reference to the sender. + sender: &'a Sender, + + // Listener waiting on the channel. + listener: Option, + + // Keeping this type `!Unpin` enables future optimizations. + #[pin] + _pin: PhantomPinned + } +} + +impl<'a, T> EventListenerFuture for ClosedInner<'a, T> { + type Output = (); + + /// Run this future with the given `Strategy`. + fn poll_with_strategy<'x, S: Strategy<'x>>( + self: Pin<&mut Self>, + strategy: &mut S, + cx: &mut S::Context, + ) -> Poll<()> { + let this = self.project(); + + // Check if the channel is closed. + if !this.sender.is_closed() { + // Channel is not closed yet - now start listening for notifications. + *this.listener = Some(this.sender.channel.closed_ops.listen()); + + // Poll using the given strategy + ready!(S::poll(strategy, &mut *this.listener, cx)); + } + Poll::Ready(()) + } +} + #[cfg(feature = "std")] use std::process::abort; diff --git a/tests/bounded.rs b/tests/bounded.rs index 460cb55..25b01f0 100644 --- a/tests/bounded.rs +++ b/tests/bounded.rs @@ -184,6 +184,29 @@ fn send() { .run(); } +#[cfg(not(target_family = "wasm"))] +#[test] +fn closed() { + let (s, r) = bounded(1); + + Parallel::new() + .add(|| { + future::block_on(s.send(7)).unwrap(); + let before = s.closed(); + let mut before = std::pin::pin!(before); + assert!(future::block_on(future::poll_once(&mut before)).is_none()); + sleep(ms(1000)); + assert_eq!(future::block_on(future::poll_once(before)), Some(())); + assert_eq!(future::block_on(future::poll_once(s.closed())), Some(())); + }) + .add(|| { + assert_eq!(future::block_on(r.recv()), Ok(7)); + sleep(ms(500)); + drop(r); + }) + .run(); +} + #[cfg(not(target_family = "wasm"))] #[test] fn force_send() {