From 6e6a3fd6e727c83ed5744c8c8e2e9fa32bd23a1f Mon Sep 17 00:00:00 2001 From: isra17 Date: Thu, 15 Aug 2024 17:09:53 -0400 Subject: [PATCH 1/2] Reject all futures once a FutureStore is closed. This fixes a race condition where a writer_drain future might be created after a disconnection, once the FutureStore has already had all of its futures rejected. In this case, the new future won't ever be completed since the connection is not writing anymore. --- aiormq/base.py | 12 +++++++++- aiormq/channel.py | 57 +++++++++++++++++++++++++++++++++-------------- 2 files changed, 51 insertions(+), 18 deletions(-) diff --git a/aiormq/base.py b/aiormq/base.py index b8bf9bd..72f351d 100644 --- a/aiormq/base.py +++ b/aiormq/base.py @@ -2,7 +2,9 @@ import asyncio from contextlib import suppress from functools import wraps -from typing import Any, Callable, Coroutine, Optional, Set, TypeVar, Union +from typing import ( + Any, Callable, Coroutine, Optional, Set, TypeVar, Union, Literal +) from weakref import WeakSet from .abc import ( @@ -25,6 +27,7 @@ class FutureStore(AbstractFutureStore): def __init__(self, loop: asyncio.AbstractEventLoop): self.futures = set() self.loop = loop + self.reject_reason: Optional[ExceptionType] | Literal[False] = False self.parent: Optional[FutureStore] = None def __on_task_done( @@ -38,6 +41,12 @@ def remover(*_: Any) -> None: return remover def add(self, future: Union[asyncio.Future, TaskWrapper]) -> None: + if self.reject_reason is not False: + if isinstance(future, TaskWrapper): + future.throw(self.reject_reason or Exception) + elif isinstance(future, asyncio.Future): + future.set_exception(self.reject_reason or Exception) + self.futures.add(future) future.add_done_callback(self.__on_task_done(future)) @@ -46,6 +55,7 @@ def add(self, future: Union[asyncio.Future, TaskWrapper]) -> None: @shield async def reject_all(self, exception: Optional[ExceptionType]) -> None: + self.reject_reason = exception tasks = [] while self.futures: diff --git a/aiormq/channel.py 
b/aiormq/channel.py index 9d7d576..89f2970 100644 --- a/aiormq/channel.py +++ b/aiormq/channel.py @@ -1,15 +1,16 @@ +from functools import wraps import asyncio import io import logging from collections import OrderedDict -from contextlib import suppress +from contextlib import asynccontextmanager, suppress from functools import partial from io import BytesIO from random import getrandbits from types import MappingProxyType from typing import ( - Any, Awaitable, Callable, Dict, List, Mapping, Optional, Set, Tuple, Type, - Union, + Any, AsyncGenerator, Awaitable, Callable, Dict, List, Mapping, Optional, + Set, Tuple, Type, TypeVar, Union, ) from uuid import UUID @@ -28,7 +29,7 @@ ConfirmationFrameType, ConsumerCallback, DeliveredMessage, ExceptionType, FrameType, GetResultType, ReturnCallback, RpcReturnType, TimeoutType, ) -from .base import Base, task +from .base import Base from .exceptions import ( AMQPChannelError, AMQPError, ChannelAccessRefused, ChannelClosed, ChannelInvalidStateError, ChannelLockedResource, ChannelNotFoundEntity, @@ -48,6 +49,21 @@ }) +T = TypeVar("T") + +TaskFunctionType = Callable[..., T] + + +def task(func: TaskFunctionType) -> TaskFunctionType: + @wraps(func) + async def wrap(self: "Channel", *args: Any, **kwargs: Any) -> Any: + if self.is_closed: + raise ChannelInvalidStateError("%r closed" % self) + return await self.create_task(func(self, *args, **kwargs)) + + return wrap + + def exception_by_code(frame: spec.Channel.Close) -> AMQPError: if frame.reply_code is None: return ChannelClosed(frame.reply_code, frame.reply_text) @@ -140,11 +156,14 @@ def set_close_reason( self.__close_method_id = method_id @property - def lock(self) -> asyncio.Lock: + @asynccontextmanager + async def lock(self) -> AsyncGenerator[None, None]: if self.is_closed: raise ChannelInvalidStateError("%r closed" % self) - - return self.__lock + async with self.__lock: + if self.is_closed: + raise ChannelInvalidStateError("%r closed" % self) + yield async def 
_get_frame(self) -> FrameType: weight, frame = await self.frames.get() @@ -478,17 +497,21 @@ async def basic_get( countdown = Countdown(timeout) async with countdown.enter_context(self.getter_lock): self.getter = self.create_future() + try: + await self.rpc( + spec.Basic.Get(queue=queue, no_ack=no_ack), + timeout=countdown.get_timeout(), + ) + except BaseException: + self.getter.cancel() + raise + else: + frame: Union[spec.Basic.GetEmpty, spec.Basic.GetOk] + message: DeliveredMessage - await self.rpc( - spec.Basic.Get(queue=queue, no_ack=no_ack), - timeout=countdown.get_timeout(), - ) - - frame: Union[spec.Basic.GetEmpty, spec.Basic.GetOk] - message: DeliveredMessage - - frame, message = await countdown(self.getter) - del self.getter + frame, message = await countdown(self.getter) + finally: + del self.getter return message From a949ce9b51cc90a1c88e69f5d357f33d28ba0a88 Mon Sep 17 00:00:00 2001 From: isra17 Date: Fri, 16 Aug 2024 10:09:58 -0400 Subject: [PATCH 2/2] Add test for the deadlock --- tests/test_connection.py | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/tests/test_connection.py b/tests/test_connection.py index ad82610..d61ab18 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -502,6 +502,35 @@ async def run(): await run() +async def test_connection_close_publish(proxy, amqp_url: URL): + url = amqp_url.with_host( + proxy.proxy_host, + ).with_port( + proxy.proxy_port, + ).update_query(heartbeat="1") + + async def run(): + connection = await aiormq.connect(url) + channel = await connection.channel() + declare_ok = await channel.queue_declare(auto_delete=True) + + # This test a bug where a disconnection happening during a call waiting + # for the channel lock would result in a deadlock. Here we get the lock + # so the call to basic_publish end up holding the lock when we have the + # proxy disconnecting. 
+ async with channel.lock: + task = asyncio.create_task(channel.basic_publish( + b"data", routing_key=declare_ok.queue + )) + await asyncio.sleep(0.5) + proxy.disconnect_all() + await asyncio.sleep(0.5) + + with pytest.raises(aiormq.ChannelInvalidStateError): + await task + + await asyncio.wait_for(run(), timeout=5) + PARSE_INT_PARAMS = ( (1, 1), ("1", 1),