From 21f8242a68529946367bbf2bdbe1026c9283da85 Mon Sep 17 00:00:00 2001 From: Dmitry Orlov Date: Tue, 21 Feb 2023 00:53:26 +0300 Subject: [PATCH 01/10] Remove TaskWrapper --- aiormq/abc.py | 47 +++-------------------- aiormq/base.py | 77 +++++++++++++++++++++----------------- aiormq/channel.py | 1 - aiormq/connection.py | 7 ++-- tests/test_connection.py | 10 ++--- tests/test_future_store.py | 26 ++----------- 6 files changed, 61 insertions(+), 107 deletions(-) diff --git a/aiormq/abc.py b/aiormq/abc.py index 774a76b..dd12e7e 100644 --- a/aiormq/abc.py +++ b/aiormq/abc.py @@ -22,43 +22,6 @@ ExceptionType = Union[BaseException, Type[BaseException]] - - -# noinspection PyShadowingNames -class TaskWrapper: - __slots__ = "_exception", "task" - - _exception: Union[BaseException, Type[BaseException]] - task: asyncio.Task - - def __init__(self, task: asyncio.Task): - self.task = task - self._exception = asyncio.CancelledError - - def throw(self, exception: ExceptionType) -> None: - self._exception = exception - self.task.cancel() - - async def __inner(self) -> Any: - try: - return await self.task - except asyncio.CancelledError as e: - raise self._exception from e - - def __await__(self, *args: Any, **kwargs: Any) -> Any: - return self.__inner().__await__() - - def cancel(self) -> None: - return self.throw(asyncio.CancelledError()) - - def __getattr__(self, item: str) -> Any: - return getattr(self.task, item) - - def __repr__(self) -> str: - return "<%s: %s>" % (self.__class__.__name__, repr(self.task)) - - -TaskType = Union[asyncio.Task, TaskWrapper] CoroutineType = Coroutine[Any, None, Any] GetResultType = Union[Basic.GetEmpty, Basic.GetOk] @@ -241,11 +204,11 @@ def marshall( class AbstractFutureStore: - futures: Set[Union[asyncio.Future, TaskType]] + futures: Set[asyncio.Future] loop: asyncio.AbstractEventLoop @abstractmethod - def add(self, future: Union[asyncio.Future, TaskWrapper]) -> None: + def add(self, future: asyncio.Future) -> None: raise NotImplementedError @abstractmethod @@ -253,7 +216,7 @@ def reject_all(self, exception: Optional[ExceptionType]) -> Any: raise NotImplementedError @abstractmethod - def create_task(self, coro: CoroutineType) -> TaskType: + def create_task(self, coro: CoroutineType) -> asyncio.Task: raise NotImplementedError @abstractmethod @@ -273,7 +236,7 @@ def _future_store_child(self) -> AbstractFutureStore: raise NotImplementedError @abstractmethod - def create_task(self, coro: CoroutineType) -> TaskType: + def create_task(self, coro: CoroutineType) -> asyncio.Future: raise NotImplementedError def create_future(self) -> asyncio.Future: @@ -630,5 +593,5 @@ async def update_secret( "CoroutineType", "DeliveredMessage", "DrainResult", "ExceptionType", "FieldArray", "FieldTable", "FieldValue", "FrameReceived", "FrameType", "GetResultType", "ReturnCallback", "RpcReturnType", "SSLCerts", - "TaskType", "TaskWrapper", "TimeoutType", "URLorStr", + "TimeoutType", "URLorStr", ) diff --git a/aiormq/base.py b/aiormq/base.py index b8bf9bd..7030eeb 100644 --- a/aiormq/base.py +++ b/aiormq/base.py @@ -2,23 +2,25 @@ import asyncio from contextlib import suppress from functools import wraps -from typing import Any, Callable, Coroutine, Optional, Set, TypeVar, Union +from typing import Any, Awaitable, Callable, Optional, Set, TypeVar from weakref import WeakSet from .abc import ( - AbstractBase, AbstractFutureStore, CoroutineType, ExceptionType, TaskType, - TaskWrapper, TimeoutType, + AbstractBase, AbstractFutureStore, CoroutineType, ExceptionType, + TimeoutType, ) -from .tools import Countdown, shield +from .tools import Countdown T = TypeVar("T") class FutureStore(AbstractFutureStore): - __slots__ = "futures", "loop", "parent" + __slots__ = ( + "futures", "loop", "parent", "__rejecting", + ) - futures: Set[Union[asyncio.Future, TaskType]] + futures: Set[asyncio.Future] weak_futures: WeakSet loop: asyncio.AbstractEventLoop @@ -26,45 +28,52 @@ def __init__(self, loop: asyncio.AbstractEventLoop): self.futures = set() self.loop = loop self.parent: Optional[FutureStore] = None + self.__rejecting: Optional[ExceptionType] = None - def __on_task_done( - self, future: Union[asyncio.Future, TaskWrapper], - ) -> Callable[..., Any]: - def remover(*_: Any) -> None: - nonlocal future - if future in self.futures: - self.futures.remove(future) - - return remover - - def add(self, future: Union[asyncio.Future, TaskWrapper]) -> None: + def add(self, future: asyncio.Future) -> None: self.futures.add(future) - future.add_done_callback(self.__on_task_done(future)) - + future.add_done_callback(self.futures.discard) if self.parent: self.parent.add(future) - @shield - async def reject_all(self, exception: Optional[ExceptionType]) -> None: + def reject_all(self, exception: Optional[ExceptionType]) -> Awaitable[None]: + self.__rejecting = exception or Exception("Rejected") + tasks = [] while self.futures: - future: Union[TaskType, asyncio.Future] = self.futures.pop() + future: asyncio.Future = self.futures.pop() + + tasks.append(future) if future.done(): continue - if isinstance(future, TaskWrapper): - future.throw(exception or Exception) - tasks.append(future) + if isinstance(future, asyncio.Task): + future.cancel() elif isinstance(future, asyncio.Future): - future.set_exception(exception or Exception) - - if tasks: - await asyncio.gather(*tasks, return_exceptions=True) - - def create_task(self, coro: CoroutineType) -> TaskType: - task: TaskWrapper = TaskWrapper(self.loop.create_task(coro)) + future.set_exception(self.__rejecting) + + async def rejecter() -> None: + nonlocal tasks + try: + if tasks: + await asyncio.gather(*tasks, return_exceptions=True) + finally: + self.__rejecting = None + + return self.loop.create_task(rejecter()) + + async def __task_wrapper(self, coro: CoroutineType) -> Any: + try: + return await coro + except asyncio.CancelledError as e: + if self.__rejecting is None: + raise + raise self.__rejecting from e + + def create_task(self, coro: CoroutineType) -> asyncio.Task: + task: asyncio.Task = self.loop.create_task(self.__task_wrapper(coro)) self.add(task) return task @@ -102,13 +111,13 @@ def _create_closing_future(self) -> asyncio.Future: def _cancel_tasks( self, exc: Optional[ExceptionType] = None, - ) -> Coroutine[Any, Any, None]: + ) -> Awaitable[None]: return self.__future_store.reject_all(exc) def _future_store_child(self) -> AbstractFutureStore: return self.__future_store.get_child() - def create_task(self, coro: CoroutineType) -> TaskType: + def create_task(self, coro: CoroutineType) -> asyncio.Task: return self.__future_store.create_task(coro) def create_future(self) -> asyncio.Future: diff --git a/aiormq/channel.py b/aiormq/channel.py index 4f1cbe0..02b7274 100644 --- a/aiormq/channel.py +++ b/aiormq/channel.py @@ -454,7 +454,6 @@ async def _reader(self) -> None: last_exception, timeout=self.CHANNEL_CLOSE_TIMEOUT, ) - @task async def _on_close(self, exc: Optional[ExceptionType] = None) -> None: if not self.connection.is_opened or self.__close_event.is_set(): return diff --git a/aiormq/connection.py b/aiormq/connection.py index 6238f69..8ecd195 100644 --- a/aiormq/connection.py +++ b/aiormq/connection.py @@ -25,7 +25,7 @@ from .abc import ( AbstractChannel, AbstractConnection, ArgumentsType, ChannelFrame, - ExceptionType, SSLCerts, TaskType, URLorStr, + ExceptionType, SSLCerts, URLorStr, ) from .auth import AuthMechanism from .base import Base, task @@ -233,8 +233,8 @@ class Connection(Base, AbstractConnection): READER_CLOSE_TIMEOUT = 2 - _reader_task: TaskType - _writer_task: TaskType + _reader_task: asyncio.Task + _writer_task: asyncio.Task write_queue: asyncio.Queue server_properties: ArgumentsType connection_tune: spec.Connection.Tune @@ -785,6 +785,7 @@ def publisher_confirms(self) -> Optional[bool]: return None return bool(publisher_confirms) + @task async def channel( self, channel_number: Optional[int] = None, diff --git a/tests/test_connection.py b/tests/test_connection.py index 8eb46a9..9447e54 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -456,7 +456,7 @@ def client_to_server(chunk: bytes) -> bytes: DISCONNECT_OFFSETS = [2 << i for i in range(1, 10)] STAIR_STEPS = list( - itertools.product([0.0, 0.005, 0.05, 0.1], DISCONNECT_OFFSETS) + itertools.product([0.0, 0.005, 0.05, 0.1], DISCONNECT_OFFSETS), ) STAIR_STEPS_IDS = [ f"[{i // len(DISCONNECT_OFFSETS)}] {t}-{s}" @@ -467,10 +467,10 @@ def client_to_server(chunk: bytes) -> bytes: @aiomisc.timeout(30) @pytest.mark.parametrize( "disconnect_time,stair", STAIR_STEPS, - ids=STAIR_STEPS_IDS + ids=STAIR_STEPS_IDS, ) async def test_connection_close_stairway( - disconnect_time: float, stair: int, proxy, amqp_url: URL + disconnect_time: float, stair: int, proxy, amqp_url: URL, ): url = amqp_url.with_host( proxy.proxy_host, @@ -486,12 +486,12 @@ async def run(): channel = await connection.channel() declare_ok = await channel.queue_declare(auto_delete=True) await channel.basic_consume( - declare_ok.queue, queue.put, no_ack=True + declare_ok.queue, queue.put, no_ack=True, ) while True: await channel.basic_publish( - b"test", routing_key=declare_ok.queue + b"test", routing_key=declare_ok.queue, ) message: DeliveredMessage = await queue.get() assert message.body == b"test" diff --git a/tests/test_future_store.py b/tests/test_future_store.py index 5352257..3edfaba 100644 --- a/tests/test_future_store.py +++ b/tests/test_future_store.py @@ -2,30 +2,25 @@ import pytest -from aiormq.abc import TaskWrapper from aiormq.base import FutureStore @pytest.fixture -def root_store(loop): +async def root_store(loop: asyncio.AbstractEventLoop): store = FutureStore(loop=loop) try: yield store finally: - loop.run_until_complete( - store.reject_all(Exception("Cancelling")), - ) + await store.reject_all(Exception("Cancelling")) @pytest.fixture -def child_store(loop, root_store): +async def child_store(loop, root_store: FutureStore): store = root_store.get_child() try: yield store finally: - loop.run_until_complete( - store.reject_all(Exception("Cancelling")), - ) + await store.reject_all(Exception("Cancelling")) async def test_reject_all( @@ -91,16 +86,3 @@ async def coro(store): assert not root_store.futures assert not child_store.futures assert not child.futures - - -async def test_task_wrapper(loop): - future = loop.create_future() - wrapped = TaskWrapper(future) - - wrapped.throw(RuntimeError()) - - with pytest.raises(asyncio.CancelledError): - await future - - with pytest.raises(RuntimeError): - await wrapped From 4e9252c90019a347097d37aad373ac8a5914d12f Mon Sep 17 00:00:00 2001 From: Dmitry Orlov Date: Tue, 21 Feb 2023 00:55:31 +0300 Subject: [PATCH 02/10] Rename rejector to wait_rejected --- aiormq/base.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/aiormq/base.py b/aiormq/base.py index 7030eeb..0da9221 100644 --- a/aiormq/base.py +++ b/aiormq/base.py @@ -48,21 +48,20 @@ def reject_all(self, exception: Optional[ExceptionType]) -> Awaitable[None]: if future.done(): continue - - if isinstance(future, asyncio.Task): + elif isinstance(future, asyncio.Task): future.cancel() elif isinstance(future, asyncio.Future): future.set_exception(self.__rejecting) - async def rejecter() -> None: + async def wait_rejected() -> None: nonlocal tasks try: - if tasks: - await asyncio.gather(*tasks, return_exceptions=True) + if not tasks: + return + await asyncio.gather(*tasks, return_exceptions=True) finally: self.__rejecting = None - - return self.loop.create_task(rejecter()) + return self.loop.create_task(wait_rejected()) async def __task_wrapper(self, coro: CoroutineType) -> Any: try: From 94ba70c372229bc899a9f1521b95574d914b48c7 Mon Sep 17 00:00:00 2001 From: Dmitry Orlov Date: Tue, 21 Feb 2023 00:56:41 +0300 Subject: [PATCH 03/10] Use RuntimeError instead Exception --- aiormq/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aiormq/base.py b/aiormq/base.py index 0da9221..dba5562 100644 --- a/aiormq/base.py +++ b/aiormq/base.py @@ -37,7 +37,7 @@ def add(self, future: asyncio.Future) -> None: self.parent.add(future) def reject_all(self, exception: Optional[ExceptionType]) -> Awaitable[None]: - self.__rejecting = exception or Exception("Rejected") + self.__rejecting = exception or RuntimeError("Has been rejected") tasks = [] From 9604b05c4b4eae0e00108c20b1ced4e1c7ed3d70 Mon Sep 17 00:00:00 2001 From: Dmitry Orlov Date: Tue, 21 Feb 2023 01:10:45 +0300 Subject: [PATCH 04/10] replace deprecated loop fixture to event_loop --- aiormq/types.py | 11 -------- tests/conftest.py | 8 +++--- tests/test_channel.py | 8 +++--- tests/test_connection.py | 55 +++++++++++++++++++++----------------- tests/test_future_store.py | 12 ++++----- tests/test_tools.py | 4 +-- 6 files changed, 47 insertions(+), 51 deletions(-) delete mode 100644 aiormq/types.py diff --git a/aiormq/types.py b/aiormq/types.py deleted file mode 100644 index 65e5947..0000000 --- a/aiormq/types.py +++ /dev/null @@ -1,11 +0,0 @@ -import warnings - -from .abc import * # noqa - - -warnings.warn( - "aiormq.types was deprecated and will be removed in " - "one of next major releases. Use aiormq.abc instead.", - category=DeprecationWarning, - stacklevel=2, -) diff --git a/tests/conftest.py b/tests/conftest.py index 0261626..bb6214b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -50,8 +50,8 @@ async def amqp_url(request): @pytest.fixture -async def amqp_connection(amqp_url, loop): - connection = Connection(amqp_url, loop=loop) +async def amqp_connection(amqp_url, event_loop): + connection = Connection(amqp_url, loop=event_loop) async with connection: yield connection @@ -131,13 +131,13 @@ async def proxy(tcp_proxy, localhost, amqp_url: URL): @pytest.fixture -async def proxy_connection(proxy: TCPProxy, amqp_url: URL, loop): +async def proxy_connection(proxy: TCPProxy, amqp_url: URL, event_loop): url = amqp_url.with_host( "localhost", ).with_port( proxy.proxy_port, ) - connection = Connection(url, loop=loop) + connection = Connection(url, loop=event_loop) await connection.connect() diff --git a/tests/test_channel.py b/tests/test_channel.py index b9652e7..1f0ee05 100644 --- a/tests/test_channel.py +++ b/tests/test_channel.py @@ -67,17 +67,17 @@ async def test_blank_body(amqp_channel: aiormq.Channel): assert message.body == b"foo bar" -async def test_bad_consumer(amqp_channel: aiormq.Channel, loop): +async def test_bad_consumer(amqp_channel: aiormq.Channel, event_loop): channel: aiormq.Channel = amqp_channel await channel.basic_qos(prefetch_count=1) declare_ok = await channel.queue_declare() - future = loop.create_future() + future = event_loop.create_future() await channel.basic_publish(b"urgent", routing_key=declare_ok.queue) - consumer_tag = loop.create_future() + consumer_tag = event_loop.create_future() async def bad_consumer(message): await channel.basic_cancel(await consumer_tag) @@ -94,7 +94,7 @@ async def bad_consumer(message): await channel.basic_reject(message.delivery.delivery_tag, requeue=True) assert message.body == b"urgent" - future = loop.create_future() + future = event_loop.create_future() await channel.basic_consume( declare_ok.queue, future.set_result, no_ack=True, diff --git a/tests/test_connection.py b/tests/test_connection.py index 9447e54..58cdc7a 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -128,18 +128,18 @@ async def test_channel_close(amqp_connection): assert channel.number not in amqp_connection.channels -async def test_conncetion_reject(loop): +async def test_conncetion_reject(event_loop): with pytest.raises(ConnectionError): await aiormq.connect( - "amqp://guest:guest@127.0.0.1:59999/", loop=loop, + "amqp://guest:guest@127.0.0.1:59999/", loop=event_loop, ) connection = aiormq.Connection( - "amqp://guest:guest@127.0.0.1:59999/", loop=loop, + "amqp://guest:guest@127.0.0.1:59999/", loop=event_loop, ) with pytest.raises(ConnectionError): - await loop.create_task(connection.connect()) + await event_loop.create_task(connection.connect()) async def test_auth_base(amqp_connection): @@ -147,7 +147,7 @@ async def test_auth_base(amqp_connection): AuthBase(amqp_connection).marshal() -async def test_auth_plain(amqp_connection, loop): +async def test_auth_plain(amqp_connection, event_loop): auth = PlainAuth(amqp_connection).marshal() assert auth == PlainAuth(amqp_connection).marshal() @@ -157,7 +157,7 @@ async def test_auth_plain(amqp_connection, loop): connection = aiormq.Connection( amqp_connection.url.with_user("foo").with_password("bar"), - loop=loop, + loop=event_loop, ) auth = PlainAuth(connection).marshal() @@ -171,7 +171,7 @@ async def test_auth_plain(amqp_connection, loop): assert auth.marshal() == "boo" -async def test_auth_external(loop): +async def test_auth_external(event_loop): url = AMQP_URL.with_scheme("amqps") url.update_query(auth="external") @@ -204,55 +204,55 @@ async def test_channel_closed(amqp_connection): await amqp_connection.close() -async def test_timeout_default(loop): - connection = aiormq.Connection(AMQP_URL, loop=loop) +async def test_timeout_default(event_loop): + connection = aiormq.Connection(AMQP_URL, loop=event_loop) await connection.connect() assert connection.timeout == 60 await connection.close() -async def test_timeout_1000(loop): +async def test_timeout_1000(event_loop): url = AMQP_URL.update_query(timeout=1000) - connection = aiormq.Connection(url, loop=loop) + connection = aiormq.Connection(url, loop=event_loop) await connection.connect() assert connection.timeout await connection.close() -async def test_heartbeat_0(loop): +async def test_heartbeat_0(event_loop): url = AMQP_URL.update_query(heartbeat=0) - connection = aiormq.Connection(url, loop=loop) + connection = aiormq.Connection(url, loop=event_loop) await connection.connect() assert connection.connection_tune.heartbeat == 0 await connection.close() -async def test_heartbeat_default(loop): - connection = aiormq.Connection(AMQP_URL, loop=loop) +async def test_heartbeat_default(event_loop): + connection = aiormq.Connection(AMQP_URL, loop=event_loop) await connection.connect() assert connection.connection_tune.heartbeat == 60 await connection.close() -async def test_heartbeat_above_range(loop): +async def test_heartbeat_above_range(event_loop): url = AMQP_URL.update_query(heartbeat=70000) - connection = aiormq.Connection(url, loop=loop) + connection = aiormq.Connection(url, loop=event_loop) await connection.connect() assert connection.connection_tune.heartbeat == 0 await connection.close() -async def test_heartbeat_under_range(loop): +async def test_heartbeat_under_range(event_loop): url = AMQP_URL.update_query(heartbeat=-1) - connection = aiormq.Connection(url, loop=loop) + connection = aiormq.Connection(url, loop=event_loop) await connection.connect() assert connection.connection_tune.heartbeat == 0 await connection.close() -async def test_heartbeat_not_int(loop): +async def test_heartbeat_not_int(event_loop): url = AMQP_URL.update_query(heartbeat="None") - connection = aiormq.Connection(url, loop=loop) + connection = aiormq.Connection(url, loop=event_loop) await connection.connect() assert connection.connection_tune.heartbeat == 0 await connection.close() @@ -322,7 +322,7 @@ async def test_return_message(amqp_connection: aiormq.Connection): assert result.delivery.routing_key == routing_key -async def test_cancel_on_queue_deleted(amqp_connection, loop): +async def test_cancel_on_queue_deleted(amqp_connection, event_loop): conn: aiormq.Connection = amqp_connection channel: aiormq.Channel = await conn.channel() deaclare_ok = await channel.queue_declare(auto_delete=True) @@ -377,8 +377,8 @@ async def test_ssl_context(): @pytest.mark.parametrize("url,vhost", URL_VHOSTS) -async def test_connection_urls_vhosts(url, vhost, loop): - assert aiormq.Connection(url, loop=loop).vhost == vhost +async def test_connection_urls_vhosts(url, vhost, event_loop): + assert aiormq.Connection(url, loop=event_loop).vhost == vhost async def test_update_secret(amqp_connection, amqp_url: URL): @@ -499,3 +499,10 @@ async def run(): for _ in range(5): with pytest.raises(aiormq.AMQPError): await run() + + +async def test_connection_close_reader(amqp_connection: aiormq.Connection): + amqp_connection._reader_task.cancel() + + with pytest.raises(asyncio.CancelledError): + await amqp_connection.channel() diff --git a/tests/test_future_store.py b/tests/test_future_store.py index 3edfaba..6a45d05 100644 --- a/tests/test_future_store.py +++ b/tests/test_future_store.py @@ -6,8 +6,8 @@ @pytest.fixture -async def root_store(loop: asyncio.AbstractEventLoop): - store = FutureStore(loop=loop) +async def root_store(event_loop: asyncio.AbstractEventLoop): + store = FutureStore(loop=event_loop) try: yield store finally: @@ -15,7 +15,7 @@ async def root_store(loop: asyncio.AbstractEventLoop): @pytest.fixture -async def child_store(loop, root_store: FutureStore): +async def child_store(event_loop, root_store: FutureStore): store = root_store.get_child() try: yield store @@ -24,7 +24,7 @@ async def child_store(loop, root_store: FutureStore): async def test_reject_all( - loop, root_store: FutureStore, child_store: FutureStore, + event_loop, root_store: FutureStore, child_store: FutureStore, ): future1 = root_store.create_future() @@ -43,7 +43,7 @@ async def test_reject_all( async def test_result( - loop, root_store: FutureStore, child_store: FutureStore, + event_loop, root_store: FutureStore, child_store: FutureStore, ): async def result(): await asyncio.sleep(0.1) @@ -53,7 +53,7 @@ async def result(): async def test_siblings( - loop, root_store: FutureStore, child_store: FutureStore, + event_loop, root_store: FutureStore, child_store: FutureStore, ): async def coro(store): await asyncio.sleep(0.1) diff --git a/tests/test_tools.py b/tests/test_tools.py index ad6715e..b5fba9f 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -44,11 +44,11 @@ def return_coroutine(): @pytest.mark.parametrize("func,result", AWAITABLE_FUNCS) -async def test_awaitable(func, result, loop): +async def test_awaitable(func, result, event_loop): assert await awaitable(func)() == result -async def test_countdown(loop): +async def test_countdown(event_loop): countdown = Countdown(timeout=0.1) await countdown(asyncio.sleep(0)) From c1767313d5ac8837b3f5aa62b0d1d18e9925761e Mon Sep 17 00:00:00 2001 From: Dmitry Orlov Date: Tue, 21 Feb 2023 01:11:00 +0300 Subject: [PATCH 05/10] remove useless shield --- aiormq/tools.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/aiormq/tools.py b/aiormq/tools.py index de6ccdf..b6b41e2 100644 --- a/aiormq/tools.py +++ b/aiormq/tools.py @@ -22,14 +22,6 @@ def censor_url(url: URL) -> URL: return url -def shield(func: Callable[..., Awaitable[T]]) -> Callable[..., Awaitable[T]]: - @wraps(func) - def wrap(*args: Any, **kwargs: Any) -> Awaitable[T]: - return asyncio.shield(func(*args, **kwargs)) - - return wrap - - def awaitable( func: Callable[..., Union[T, Awaitable[T]]], ) -> Callable[..., Coroutine[Any, Any, T]]: From d0d9c8150bd43f908a0371625daba3d935ed1d90 Mon Sep 17 00:00:00 2001 From: Dmitry Orlov Date: Tue, 21 Feb 2023 01:11:18 +0300 Subject: [PATCH 06/10] add forgotten set_close_reason --- aiormq/connection.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/aiormq/connection.py b/aiormq/connection.py index 8ecd195..8280f50 100644 --- a/aiormq/connection.py +++ b/aiormq/connection.py @@ -533,7 +533,14 @@ async def __handle_heartbeat(self, _: Heartbeat) -> None: return async def __handle_close(self, frame: spec.Connection.Close) -> None: - log.exception( + self.set_close_reason( + frame.reply_code or -1, + frame.reply_text or "", + frame.class_id or -1, + frame.method_id or -1, + ) + + log.error( "Unexpected connection close from remote \"%s\", " "Connection.Close(reply_code=%r, reply_text=%r)", self, frame.reply_code, frame.reply_text, From cb91bb356262e651dc8a2a8230fd16515e87ea74 Mon Sep 17 00:00:00 2001 From: Dmitry Orlov Date: Tue, 21 Feb 2023 01:21:14 +0300 Subject: [PATCH 07/10] add forgotten set_close_reason for channel --- aiormq/channel.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/aiormq/channel.py b/aiormq/channel.py index 02b7274..d91749b 100644 --- a/aiormq/channel.py +++ b/aiormq/channel.py @@ -398,6 +398,12 @@ async def _on_cancel_frame( async def _on_close_frame(self, frame: spec.Channel.Close) -> None: exc: BaseException = exception_by_code(frame) + self.set_close_reason( + frame.reply_code or -1, + frame.reply_text or "", + frame.class_id or -1, + frame.method_id or -1 + ) with suppress(asyncio.QueueFull): self.write_queue.put_nowait( ChannelFrame.marshall( From 7c144fda5d1f2c93c1b06fc2feaa1ed3ea9e0ae4 Mon Sep 17 00:00:00 2001 From: Dmitry Orlov Date: Tue, 21 Feb 2023 01:26:22 +0300 Subject: [PATCH 08/10] Remove awaitable --- aiormq/channel.py | 6 +++--- aiormq/tools.py | 27 +-------------------------- tests/test_tools.py | 7 +------ 3 files changed, 5 insertions(+), 35 deletions(-) diff --git a/aiormq/channel.py b/aiormq/channel.py index d91749b..4aed81d 100644 --- a/aiormq/channel.py +++ b/aiormq/channel.py @@ -21,7 +21,7 @@ from pamqp.exceptions import AMQPFrameError from pamqp.header import ContentHeader -from aiormq.tools import Countdown, awaitable +from aiormq.tools import Countdown from .abc import ( AbstractChannel, AbstractConnection, ArgumentsType, ChannelFrame, @@ -402,7 +402,7 @@ async def _on_close_frame(self, frame: spec.Channel.Close) -> None: frame.reply_code or -1, frame.reply_text or "", frame.class_id or -1, - frame.method_id or -1 + frame.method_id or -1, ) with suppress(asyncio.QueueFull): self.write_queue.put_nowait( @@ -526,7 +526,7 @@ async def basic_consume( if consumer_tag in self.consumers: raise DuplicateConsumerTag(self.number) - self.consumers[consumer_tag] = awaitable(consumer_callback) + self.consumers[consumer_tag] = consumer_callback return await self.rpc( spec.Basic.Consume( diff --git a/aiormq/tools.py b/aiormq/tools.py index b6b41e2..5927a84 100644 --- a/aiormq/tools.py +++ b/aiormq/tools.py @@ -1,12 +1,8 @@ import asyncio import platform import time -from functools import wraps from types import TracebackType -from typing import ( - Any, AsyncContextManager, Awaitable, Callable, Coroutine, Optional, Type, - TypeVar, Union, -) +from typing import Any, AsyncContextManager, Awaitable, Optional, Type, TypeVar from yarl import URL @@ -22,27 +18,6 @@ def censor_url(url: URL) -> URL: return url -def awaitable( - func: Callable[..., Union[T, Awaitable[T]]], -) -> Callable[..., Coroutine[Any, Any, T]]: - # Avoid python 3.8+ warning - if asyncio.iscoroutinefunction(func): - return func # type: ignore - - @wraps(func) - async def wrap(*args: Any, **kwargs: Any) -> T: - result = func(*args, **kwargs) - - if hasattr(result, "__await__"): - return await result # type: ignore - if asyncio.iscoroutine(result) or asyncio.isfuture(result): - return await result - - return result # type: ignore - - return wrap - - class Countdown: __slots__ = "loop", "deadline" diff --git a/tests/test_tools.py b/tests/test_tools.py index b5fba9f..d040605 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -2,7 +2,7 @@ import pytest -from aiormq.tools import Countdown, awaitable +from aiormq.tools import Countdown def simple_func(): @@ -43,11 +43,6 @@ def return_coroutine(): ] -@pytest.mark.parametrize("func,result", AWAITABLE_FUNCS) -async def test_awaitable(func, result, event_loop): - assert await awaitable(func)() == result - - async def test_countdown(event_loop): countdown = Countdown(timeout=0.1) await countdown(asyncio.sleep(0)) From dd65b94b024ec9fb6d47705a69b2131365372890 Mon Sep 17 00:00:00 2001 From: Dmitry Orlov Date: Tue, 21 Feb 2023 01:54:29 +0300 Subject: [PATCH 09/10] remove tools --- aiormq/base.py | 7 +- aiormq/channel.py | 374 ++++++++++++++++++++++--------------------- aiormq/connection.py | 28 ++-- aiormq/tools.py | 85 ---------- tests/test_tools.py | 19 --- 5 files changed, 209 insertions(+), 304 deletions(-) delete mode 100644 aiormq/tools.py diff --git a/aiormq/base.py b/aiormq/base.py index dba5562..a36637a 100644 --- a/aiormq/base.py +++ b/aiormq/base.py @@ -9,7 +9,6 @@ AbstractBase, AbstractFutureStore, CoroutineType, ExceptionType, TimeoutType, ) -from .tools import Countdown T = TypeVar("T") @@ -64,6 +63,8 @@ async def wait_rejected() -> None: return self.loop.create_task(wait_rejected()) async def __task_wrapper(self, coro: CoroutineType) -> Any: + if coro is None: + return try: return await coro except asyncio.CancelledError as e: @@ -144,9 +145,7 @@ async def close( ) -> None: if self.is_closed: return None - - countdown = Countdown(timeout) - await countdown(self.__closer(exc)) + await asyncio.wait_for(self.__closer(exc), timeout=timeout) def __repr__(self) -> str: cls_name = self.__class__.__name__ diff --git a/aiormq/channel.py b/aiormq/channel.py index 4aed81d..64b2b2c 100644 --- a/aiormq/channel.py +++ b/aiormq/channel.py @@ -21,8 +21,6 @@ from pamqp.exceptions import AMQPFrameError from pamqp.header import ContentHeader -from aiormq.tools import Countdown - from .abc import ( AbstractChannel, AbstractConnection, ArgumentsType, ChannelFrame, ConfirmationFrameType, ConsumerCallback, DeliveredMessage, ExceptionType, @@ -155,31 +153,26 @@ def __str__(self) -> str: return str(self.number) @task - async def rpc( - self, frame: Frame, timeout: TimeoutType = None, - ) -> RpcReturnType: + async def rpc(self, frame: Frame) -> RpcReturnType: if self.__close_event.is_set(): raise ChannelInvalidStateError("Channel closed by RPC timeout") - countdown = Countdown(timeout) lock = self.lock - async with countdown.enter_context(lock): + async with lock: try: - await countdown( - self.write_queue.put( - ChannelFrame.marshall( - channel_number=self.number, - frames=[frame], - ), + await self.write_queue.put( + ChannelFrame.marshall( + channel_number=self.number, + frames=[frame], ), ) if not (frame.synchronous or getattr(frame, "nowait", False)): return None - result = await countdown(self.rpc_frames.get()) + result = await self.rpc_frames.get() self.rpc_frames.task_done() @@ -216,16 +209,18 @@ async def rpc( raise async def open(self, timeout: TimeoutType = None) -> spec.Channel.OpenOk: - frame: spec.Channel.OpenOk = await self.rpc( - spec.Channel.Open(), timeout=timeout, - ) + async def opener() -> spec.Channel.OpenOk: + frame: spec.Channel.OpenOk = await self.rpc( + spec.Channel.Open(), + ) - if self.publisher_confirms: - await self.rpc(spec.Confirm.Select()) + if self.publisher_confirms: + await self.rpc(spec.Confirm.Select()) - if frame is None: # pragma: no cover - raise AMQPFrameError(frame) - return frame + if frame is None: # pragma: no cover + raise AMQPFrameError(frame) + return frame + return await asyncio.wait_for(opener(), timeout=timeout) async def __get_content_frame(self) -> ContentBody: content_frame = await self._get_frame() @@ -464,11 +459,13 @@ async def _on_close(self, exc: Optional[ExceptionType] = None) -> None: if not self.connection.is_opened or self.__close_event.is_set(): return - await self.rpc( - spec.Channel.Close( - reply_code=self.__close_reply_code, - class_id=self.__close_class_id, - method_id=self.__close_method_id, + await asyncio.wait_for( + self.rpc( + spec.Channel.Close( + reply_code=self.__close_reply_code, + class_id=self.__close_class_id, + method_id=self.__close_method_id, + ), ), timeout=self.connection.connection_tune.heartbeat or None, ) @@ -479,31 +476,25 @@ async def basic_get( self, queue: str = "", no_ack: bool = False, timeout: TimeoutType = None, ) -> DeliveredMessage: - - countdown = Countdown(timeout) - async with countdown.enter_context(self.getter_lock): - self.getter = self.create_future() - - await self.rpc( - spec.Basic.Get(queue=queue, no_ack=no_ack), - timeout=countdown.get_timeout(), - ) - - frame: Union[spec.Basic.GetEmpty, spec.Basic.GetOk] - message: DeliveredMessage - - frame, message = await countdown(self.getter) - del self.getter - - return message + async def getter() -> DeliveredMessage: + async with self.getter_lock: + self.getter = self.create_future() + await self.rpc(spec.Basic.Get(queue=queue, no_ack=no_ack)) + frame: Union[spec.Basic.GetEmpty, spec.Basic.GetOk] + message: DeliveredMessage + frame, message = await self.getter + del self.getter + return message + return await asyncio.wait_for(getter(), timeout=timeout) async def basic_cancel( self, consumer_tag: str, *, nowait: bool = False, timeout: TimeoutType = None, ) -> spec.Basic.CancelOk: - return await self.rpc( - spec.Basic.Cancel(consumer_tag=consumer_tag, nowait=nowait), - timeout=timeout, + return await asyncio.wait_for( + self.rpc( + spec.Basic.Cancel(consumer_tag=consumer_tag, nowait=nowait), + ), timeout=timeout, ) async def basic_consume( @@ -528,15 +519,16 @@ async def basic_consume( self.consumers[consumer_tag] = consumer_callback - return await self.rpc( - spec.Basic.Consume( - queue=queue, - no_ack=no_ack, - exclusive=exclusive, - consumer_tag=consumer_tag, - arguments=arguments, - ), - timeout=timeout, + return await asyncio.wait_for( + self.rpc( + spec.Basic.Consume( + queue=queue, + no_ack=no_ack, + exclusive=exclusive, + consumer_tag=consumer_tag, + arguments=arguments, + ), + ), timeout=timeout, ) async def basic_ack( @@ -634,7 +626,6 @@ async def basic_publish( ) -> Optional[ConfirmationFrameType]: _check_routing_key(routing_key) drain_future = self.create_future() if wait else None - countdown = Countdown(timeout=timeout) publish_frame = spec.Basic.Publish( exchange=exchange, @@ -648,60 +639,60 @@ async def basic_publish( body_size=len(body), ) - if not content_header.properties.message_id: - # UUID compatible random bytes - rnd_uuid = UUID(int=getrandbits(128), version=4) - content_header.properties.message_id = rnd_uuid.hex + async def publisher() -> Optional[ConfirmationFrameType]: + if not content_header.properties.message_id: + # UUID compatible random bytes + rnd_uuid = UUID(int=getrandbits(128), version=4) + content_header.properties.message_id = rnd_uuid.hex - confirmation: Optional[ConfirmationType] = None + confirmation: Optional[ConfirmationType] = None - async with countdown.enter_context(self.lock): - self.delivery_tag += 1 + async with self.lock: + self.delivery_tag += 1 - if self.publisher_confirms: - message_id = content_header.properties.message_id + if self.publisher_confirms: + message_id = content_header.properties.message_id - if self.delivery_tag not in self.confirmations: - self.confirmations[ - self.delivery_tag - ] = self.create_future() + if self.delivery_tag not in self.confirmations: + self.confirmations[ + self.delivery_tag + ] = self.create_future() - confirmation = self.confirmations[self.delivery_tag] - self.message_id_delivery_tag[message_id] = self.delivery_tag + confirmation = self.confirmations[self.delivery_tag] + self.message_id_delivery_tag[message_id] = self.delivery_tag - if confirmation is None: - return + if confirmation is None: + return - confirmation.add_done_callback( - lambda _: self.message_id_delivery_tag.pop( - message_id, None, - ), - ) + confirmation.add_done_callback( + lambda _: self.message_id_delivery_tag.pop( + message_id, None, + ), + ) - body_frames: List[Union[FrameType, ContentBody]] - body_frames = [publish_frame, content_header] - body_frames += self._split_body(body) + body_frames: List[Union[FrameType, ContentBody]] + body_frames = [publish_frame, content_header] + body_frames += self._split_body(body) - await countdown( - self.write_queue.put( + await self.write_queue.put( ChannelFrame.marshall( frames=body_frames, channel_number=self.number, drain_future=drain_future, ), - ), - ) + ) - if drain_future: - await drain_future + if drain_future: + await drain_future - if not self.publisher_confirms: - return None + if not self.publisher_confirms: + return None - if confirmation is None: - return None + if confirmation is None: + return None - return await countdown(confirmation) + return await confirmation + return await asyncio.wait_for(publisher(), timeout=timeout) async def basic_qos( self, @@ -711,13 +702,14 @@ async def basic_qos( global_: bool = False, timeout: TimeoutType = None, ) -> spec.Basic.QosOk: - return await self.rpc( - spec.Basic.Qos( - prefetch_size=prefetch_size or 0, - prefetch_count=prefetch_count or 0, - global_=global_, - ), - timeout=timeout, + return await asyncio.wait_for( + self.rpc( + spec.Basic.Qos( + prefetch_size=prefetch_size or 0, + prefetch_count=prefetch_count or 0, + global_=global_, + ), + ), timeout=timeout, ) async def basic_recover( @@ -730,7 +722,9 @@ async def basic_recover( else: frame = spec.Basic.Recover(requeue=requeue) - return await self.rpc(frame, timeout=timeout) + return await asyncio.wait_for( + self.rpc(frame), timeout=timeout, + ) async def exchange_declare( self, @@ -745,18 +739,19 @@ async def exchange_declare( arguments: Optional[Dict[str, Any]] = None, timeout: TimeoutType = None, ) -> spec.Exchange.DeclareOk: - return await self.rpc( - spec.Exchange.Declare( - exchange=str(exchange), - exchange_type=str(exchange_type), - passive=bool(passive), - durable=bool(durable), - auto_delete=bool(auto_delete), - internal=bool(internal), - nowait=bool(nowait), - arguments=arguments, - ), - timeout=timeout, + return await asyncio.wait_for( + self.rpc( + spec.Exchange.Declare( + exchange=str(exchange), + exchange_type=str(exchange_type), + passive=bool(passive), + durable=bool(durable), + auto_delete=bool(auto_delete), + internal=bool(internal), + nowait=bool(nowait), + arguments=arguments, + ), + ), timeout=timeout, ) async def exchange_delete( @@ -767,11 +762,12 @@ async def exchange_delete( nowait: bool = False, timeout: TimeoutType = None, ) -> spec.Exchange.DeleteOk: - return await self.rpc( - spec.Exchange.Delete( - exchange=exchange, nowait=nowait, if_unused=if_unused, - ), - timeout=timeout, + return await asyncio.wait_for( + self.rpc( + spec.Exchange.Delete( + exchange=exchange, nowait=nowait, if_unused=if_unused, + ), + ), timeout=timeout, ) async def exchange_bind( @@ -785,15 +781,16 @@ async def exchange_bind( timeout: TimeoutType = None, ) -> spec.Exchange.BindOk: _check_routing_key(routing_key) - return await self.rpc( - spec.Exchange.Bind( - destination=destination, - source=source, - routing_key=routing_key, - nowait=nowait, - arguments=arguments, - ), - timeout=timeout, + return await asyncio.wait_for( + self.rpc( + spec.Exchange.Bind( + destination=destination, + source=source, + routing_key=routing_key, + nowait=nowait, + arguments=arguments, + ), + ), timeout=timeout, ) async def exchange_unbind( @@ -807,24 +804,26 @@ async def exchange_unbind( timeout: TimeoutType = None, ) -> spec.Exchange.UnbindOk: _check_routing_key(routing_key) - return await self.rpc( - spec.Exchange.Unbind( - destination=destination, - source=source, - routing_key=routing_key, - nowait=nowait, - arguments=arguments, - ), - timeout=timeout, + return await asyncio.wait_for( + self.rpc( + spec.Exchange.Unbind( + destination=destination, + source=source, + routing_key=routing_key, + nowait=nowait, + arguments=arguments, + ), + ), timeout=timeout, ) async def flow( self, active: bool, timeout: TimeoutType = None, ) -> spec.Channel.FlowOk: - return await self.rpc( - spec.Channel.Flow(active=active), - timeout=timeout, + return await asyncio.wait_for( + self.rpc( + spec.Channel.Flow(active=active), + ), timeout=timeout, ) async def queue_bind( @@ -837,15 +836,16 @@ async def queue_bind( timeout: TimeoutType = None, ) -> spec.Queue.BindOk: _check_routing_key(routing_key) - return await self.rpc( - spec.Queue.Bind( - queue=queue, - exchange=exchange, - routing_key=routing_key, - nowait=nowait, - arguments=arguments, - ), - timeout=timeout, + return await asyncio.wait_for( + self.rpc( + spec.Queue.Bind( + queue=queue, + exchange=exchange, + routing_key=routing_key, + nowait=nowait, + arguments=arguments, + ), + ), timeout=timeout, ) async def queue_declare( @@ -860,17 +860,18 @@ async def queue_declare( arguments: Optional[ArgumentsType] = None, timeout: TimeoutType = None, ) -> spec.Queue.DeclareOk: - return await self.rpc( - spec.Queue.Declare( - queue=queue, - passive=bool(passive), - durable=bool(durable), - exclusive=bool(exclusive), - auto_delete=bool(auto_delete), - nowait=bool(nowait), - arguments=arguments, - ), - timeout=timeout, + return await asyncio.wait_for( + self.rpc( + spec.Queue.Declare( + queue=queue, + passive=bool(passive), + durable=bool(durable), + exclusive=bool(exclusive), + auto_delete=bool(auto_delete), + nowait=bool(nowait), + arguments=arguments, + ), + ), timeout=timeout, ) async def queue_delete( @@ -881,22 +882,23 @@ async def queue_delete( nowait: bool = False, timeout: TimeoutType = None, ) -> spec.Queue.DeleteOk: - return await self.rpc( - spec.Queue.Delete( - queue=queue, - if_unused=if_unused, - if_empty=if_empty, - nowait=nowait, - ), - timeout=timeout, + return await asyncio.wait_for( + self.rpc( + spec.Queue.Delete( + queue=queue, + if_unused=if_unused, + if_empty=if_empty, + nowait=nowait, + ), + ), timeout=timeout, ) async def queue_purge( self, queue: str = "", nowait: bool = False, timeout: TimeoutType = None, ) -> spec.Queue.PurgeOk: - return await self.rpc( - spec.Queue.Purge(queue=queue, nowait=nowait), + return await asyncio.wait_for( + self.rpc(spec.Queue.Purge(queue=queue, nowait=nowait)), timeout=timeout, ) @@ -909,34 +911,40 @@ async def queue_unbind( timeout: TimeoutType = None, ) -> spec.Queue.UnbindOk: _check_routing_key(routing_key) - return await self.rpc( - spec.Queue.Unbind( - routing_key=routing_key, - arguments=arguments, - queue=queue, - exchange=exchange, - ), - timeout=timeout, + return await asyncio.wait_for( + self.rpc( + spec.Queue.Unbind( + routing_key=routing_key, + arguments=arguments, + queue=queue, + exchange=exchange, + ), + ), timeout=timeout, ) async def tx_commit( self, timeout: TimeoutType = None, ) -> spec.Tx.CommitOk: - return await self.rpc(spec.Tx.Commit(), timeout=timeout) + return await asyncio.wait_for( + self.rpc(spec.Tx.Commit()), timeout=timeout, + ) async def tx_rollback( self, timeout: TimeoutType = None, ) -> spec.Tx.RollbackOk: - return await self.rpc(spec.Tx.Rollback(), timeout=timeout) + return await asyncio.wait_for( + self.rpc(spec.Tx.Rollback()), timeout=timeout, + ) async def tx_select(self, timeout: TimeoutType = None) -> spec.Tx.SelectOk: - return await self.rpc(spec.Tx.Select(), timeout=timeout) + return await asyncio.wait_for( + self.rpc(spec.Tx.Select()), timeout=timeout, + ) async def confirm_delivery( self, nowait: bool = False, timeout: TimeoutType = None, ) -> spec.Confirm.SelectOk: - return await self.rpc( - spec.Confirm.Select(nowait=nowait), - timeout=timeout, + return await asyncio.wait_for( + self.rpc(spec.Confirm.Select(nowait=nowait)), timeout=timeout, ) diff --git a/aiormq/connection.py b/aiormq/connection.py index 8280f50..b7f4dba 100644 --- a/aiormq/connection.py +++ b/aiormq/connection.py @@ -37,7 +37,6 @@ ConnectionResourceError, ConnectionSyntaxError, ConnectionUnexpectedFrame, IncompatibleProtocolError, ProbableAuthenticationError, ) -from .tools import Countdown, censor_url # noinspection PyUnresolvedReferences @@ -331,7 +330,9 @@ def is_opened(self) -> bool: ) def __str__(self) -> str: - return str(censor_url(self.url)) + if self.url.password is not None: + return str(self.url.with_password("******")) + return str(self.url) def _get_ssl_context(self) -> ssl.SSLContext: context = ssl.create_default_context( @@ -862,18 +863,19 @@ async def update_secret( ], ) - countdown = Countdown(timeout) + async def updater() -> spec.Connection.UpdateSecretOk: + async with self.__update_secret_lock: + self.__update_secret_future = self.loop.create_future() + await self.write_queue.put(channel_frame) + try: + response: spec.Connection.UpdateSecretOk = ( + await self.__update_secret_future + ) + finally: + self.__update_secret_future = None + return response - async with countdown.enter_context(self.__update_secret_lock): - self.__update_secret_future = self.loop.create_future() - await self.write_queue.put(channel_frame) - try: - response: spec.Connection.UpdateSecretOk = ( - await countdown(self.__update_secret_future) - ) - finally: - self.__update_secret_future = None - return response + return await asyncio.wait_for(updater(), timeout=timeout) async def __aenter__(self) -> AbstractConnection: if not self.is_opened: diff --git a/aiormq/tools.py b/aiormq/tools.py deleted file mode 100644 index 5927a84..0000000 --- a/aiormq/tools.py +++ /dev/null @@ -1,85 +0,0 @@ -import asyncio -import platform -import time -from types import TracebackType -from typing import Any, AsyncContextManager, Awaitable, Optional, Type, TypeVar - -from yarl import URL - -from aiormq.abc import TimeoutType - - -T = TypeVar("T") - - -def censor_url(url: URL) -> URL: - if url.password is not None: - return url.with_password("******") - return url - - -class Countdown: - __slots__ = "loop", "deadline" - - if platform.system() == "Windows": - @staticmethod - def _now() -> float: - # windows monotonic timer resolution is not enough. - # Have to use time.time() - return time.time() - else: - @staticmethod - def _now() -> float: - return time.monotonic() - - def __init__(self, timeout: TimeoutType = None): - self.loop: asyncio.AbstractEventLoop = asyncio.get_event_loop() - self.deadline: TimeoutType = None - - if timeout is not None: - self.deadline = self._now() + timeout - - def get_timeout(self) -> TimeoutType: - if self.deadline is None: - return None - - current = self._now() - if current >= self.deadline: - raise asyncio.TimeoutError - - return self.deadline - current - - async def __call__(self, coro: Awaitable[T]) -> T: - try: - timeout = self.get_timeout() - except asyncio.TimeoutError: - fut = asyncio.ensure_future(coro) - fut.cancel() - await asyncio.gather(fut, return_exceptions=True) - raise - - if self.deadline is None and not timeout: - return await coro - return await asyncio.wait_for(coro, timeout=timeout) - - def enter_context( - self, ctx: AsyncContextManager[T], - ) -> AsyncContextManager[T]: - return CountdownContext(self, ctx) - - -class CountdownContext(AsyncContextManager): - def __init__(self, countdown: Countdown, ctx: AsyncContextManager): - self.countdown: Countdown = countdown - self.ctx: AsyncContextManager = ctx - - async def __aenter__(self) -> T: - return await self.countdown(self.ctx.__aenter__()) - - async def __aexit__( - self, exc_type: Optional[Type[BaseException]], - exc_val: Optional[BaseException], exc_tb: Optional[TracebackType], - ) -> Any: - return await self.countdown( - self.ctx.__aexit__(exc_type, exc_val, exc_tb), - ) diff --git a/tests/test_tools.py b/tests/test_tools.py index d040605..ca382f1 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -1,9 +1,5 @@ import asyncio -import pytest - -from aiormq.tools import Countdown - def simple_func(): return 1 @@ -41,18 +37,3 @@ def return_coroutine(): (await_future, 6), (return_coroutine, 6), ] - - -async def test_countdown(event_loop): - countdown = Countdown(timeout=0.1) - await countdown(asyncio.sleep(0)) - - # waiting for the countdown exceeded - await asyncio.sleep(0.2) - - task = asyncio.create_task(asyncio.sleep(0)) - - with pytest.raises(asyncio.TimeoutError): - await countdown(task) - - assert task.cancelled() From 2f3cfdc8b381cd307e3dac6254313735c59613e4 Mon Sep 17 00:00:00 2001 From: Dmitry Orlov Date: Sun, 5 Mar 2023 00:14:57 +0300 Subject: [PATCH 10/10] WIP --- aiormq/abc.py | 179 +++++++++++------------------------- aiormq/base.py | 183 ++++++++----------------------------- aiormq/channel.py | 71 +++++++------- aiormq/connection.py | 150 ++++++++++++++---------------- tests/test_future_store.py | 88 ------------------ 5 files changed, 194 insertions(+), 477 deletions(-) delete mode 100644 tests/test_future_store.py diff --git a/aiormq/abc.py b/aiormq/abc.py index dd12e7e..0b86ac5 100644 --- a/aiormq/abc.py +++ b/aiormq/abc.py @@ -203,62 +203,16 @@ def marshall( ) -class AbstractFutureStore: - futures: Set[asyncio.Future] - loop: asyncio.AbstractEventLoop - - @abstractmethod - def add(self, future: asyncio.Future) -> None: - raise NotImplementedError - - @abstractmethod - def reject_all(self, exception: Optional[ExceptionType]) -> Any: - raise NotImplementedError - - @abstractmethod - def create_task(self, coro: CoroutineType) -> asyncio.Task: - raise NotImplementedError - - @abstractmethod - def create_future(self) -> asyncio.Future: - raise NotImplementedError - - @abstractmethod - def get_child(self) -> "AbstractFutureStore": - raise NotImplementedError - - class AbstractBase(ABC): loop: asyncio.AbstractEventLoop @abstractmethod - def _future_store_child(self) -> AbstractFutureStore: - raise NotImplementedError - - @abstractmethod - def create_task(self, coro: CoroutineType) -> asyncio.Future: - raise NotImplementedError - - def create_future(self) -> asyncio.Future: - raise NotImplementedError + def create_task(self, coro: CoroutineType) -> asyncio.Future: ... - @abstractmethod - async def _on_close(self, exc: Optional[Exception] = None) -> None: - raise NotImplementedError + def create_future(self) -> asyncio.Future: ... @abstractmethod - async def close( - self, exc: Optional[ExceptionType] = asyncio.CancelledError(), - ) -> None: - raise NotImplementedError - - @abstractmethod - def __str__(self) -> str: - raise NotImplementedError - - @abstractproperty - def is_closed(self) -> bool: - raise NotImplementedError + def close(self, exc: Optional[ExceptionType] = None) -> Awaitable[Any]: ... class AbstractChannel(AbstractBase): @@ -269,22 +223,24 @@ class AbstractChannel(AbstractBase): closing: asyncio.Future @abstractmethod - async def open(self) -> spec.Channel.OpenOk: - pass + async def close( + self, exc: Optional[BaseException] = None + ) -> Optional[spec.Channel.CloseOk]: ... + + @abstractmethod + async def open(self) -> spec.Channel.OpenOk: ... @abstractmethod async def basic_get( self, queue: str = "", no_ack: bool = False, timeout: TimeoutType = None, - ) -> DeliveredMessage: - raise NotImplementedError + ) -> DeliveredMessage: ... @abstractmethod async def basic_cancel( self, consumer_tag: str, *, nowait: bool = False, timeout: TimeoutType = None, - ) -> spec.Basic.CancelOk: - raise NotImplementedError + ) -> spec.Basic.CancelOk: ... @abstractmethod async def basic_consume( @@ -297,14 +253,12 @@ async def basic_consume( arguments: Optional[ArgumentsType] = None, consumer_tag: Optional[str] = None, timeout: TimeoutType = None, - ) -> spec.Basic.ConsumeOk: - raise NotImplementedError + ) -> spec.Basic.ConsumeOk: ... @abstractmethod def basic_ack( self, delivery_tag: int, multiple: bool = False, wait: bool = True, - ) -> DrainResult: - raise NotImplementedError + ) -> DrainResult: ... @abstractmethod def basic_nack( @@ -313,14 +267,12 @@ def basic_nack( multiple: bool = False, requeue: bool = True, wait: bool = True, - ) -> DrainResult: - raise NotImplementedError + ) -> DrainResult: ... @abstractmethod def basic_reject( self, delivery_tag: int, *, requeue: bool = True, wait: bool = True, - ) -> DrainResult: - raise NotImplementedError + ) -> DrainResult: ... @abstractmethod async def basic_publish( @@ -333,8 +285,7 @@ async def basic_publish( mandatory: bool = False, immediate: bool = False, timeout: TimeoutType = None, - ) -> Optional[ConfirmationFrameType]: - raise NotImplementedError + ) -> Optional[ConfirmationFrameType]: ... @abstractmethod async def basic_qos( @@ -344,15 +295,13 @@ async def basic_qos( prefetch_count: Optional[int] = None, global_: bool = False, timeout: TimeoutType = None, - ) -> spec.Basic.QosOk: - raise NotImplementedError + ) -> spec.Basic.QosOk: ... @abstractmethod async def basic_recover( self, *, nowait: bool = False, requeue: bool = False, timeout: TimeoutType = None, - ) -> spec.Basic.RecoverOk: - raise NotImplementedError + ) -> spec.Basic.RecoverOk: ... @abstractmethod async def exchange_declare( @@ -367,8 +316,7 @@ async def exchange_declare( nowait: bool = False, arguments: Optional[ArgumentsType] = None, timeout: TimeoutType = None, - ) -> spec.Exchange.DeclareOk: - raise NotImplementedError + ) -> spec.Exchange.DeclareOk: ... @abstractmethod async def exchange_delete( @@ -378,8 +326,7 @@ async def exchange_delete( if_unused: bool = False, nowait: bool = False, timeout: TimeoutType = None, - ) -> spec.Exchange.DeleteOk: - raise NotImplementedError + ) -> spec.Exchange.DeleteOk: ... @abstractmethod async def exchange_bind( @@ -391,8 +338,7 @@ async def exchange_bind( nowait: bool = False, arguments: Optional[ArgumentsType] = None, timeout: TimeoutType = None, - ) -> spec.Exchange.BindOk: - raise NotImplementedError + ) -> spec.Exchange.BindOk: ... @abstractmethod async def exchange_unbind( @@ -404,15 +350,13 @@ async def exchange_unbind( nowait: bool = False, arguments: Optional[ArgumentsType] = None, timeout: TimeoutType = None, - ) -> spec.Exchange.UnbindOk: - raise NotImplementedError + ) -> spec.Exchange.UnbindOk: ... @abstractmethod async def flow( self, active: bool, timeout: TimeoutType = None, - ) -> spec.Channel.FlowOk: - raise NotImplementedError + ) -> spec.Channel.FlowOk: ... @abstractmethod async def queue_bind( @@ -423,8 +367,7 @@ async def queue_bind( nowait: bool = False, arguments: Optional[ArgumentsType] = None, timeout: TimeoutType = None, - ) -> spec.Queue.BindOk: - raise NotImplementedError + ) -> spec.Queue.BindOk: ... @abstractmethod async def queue_declare( @@ -438,8 +381,7 @@ async def queue_declare( nowait: bool = False, arguments: Optional[ArgumentsType] = None, timeout: TimeoutType = None, - ) -> spec.Queue.DeclareOk: - raise NotImplementedError + ) -> spec.Queue.DeclareOk: ... @abstractmethod async def queue_delete( @@ -449,15 +391,13 @@ async def queue_delete( if_empty: bool = False, nowait: bool = False, timeout: TimeoutType = None, - ) -> spec.Queue.DeleteOk: - raise NotImplementedError + ) -> spec.Queue.DeleteOk: ... @abstractmethod async def queue_purge( self, queue: str = "", nowait: bool = False, timeout: TimeoutType = None, - ) -> spec.Queue.PurgeOk: - raise NotImplementedError + ) -> spec.Queue.PurgeOk: ... @abstractmethod async def queue_unbind( @@ -467,31 +407,28 @@ async def queue_unbind( routing_key: str = "", arguments: Optional[ArgumentsType] = None, timeout: TimeoutType = None, - ) -> spec.Queue.UnbindOk: - raise NotImplementedError + ) -> spec.Queue.UnbindOk: ... @abstractmethod async def tx_commit( self, timeout: TimeoutType = None, - ) -> spec.Tx.CommitOk: - raise NotImplementedError + ) -> spec.Tx.CommitOk: ... @abstractmethod async def tx_rollback( self, timeout: TimeoutType = None, - ) -> spec.Tx.RollbackOk: - raise NotImplementedError + ) -> spec.Tx.RollbackOk: ... @abstractmethod - async def tx_select(self, timeout: TimeoutType = None) -> spec.Tx.SelectOk: - raise NotImplementedError + async def tx_select( + self, timeout: TimeoutType = None + ) -> spec.Tx.SelectOk: ... @abstractmethod async def confirm_delivery( self, nowait: bool = False, timeout: TimeoutType = None, - ) -> spec.Confirm.SelectOk: - raise NotImplementedError + ) -> spec.Confirm.SelectOk: ... class AbstractConnection(AbstractBase): @@ -502,6 +439,7 @@ class AbstractConnection(AbstractBase): # Allow three missed heartbeats (based on heartbeat(timeout) HEARTBEAT_GRACE_MULTIPLIER: int + loop: asyncio.AbstractEventLoop server_properties: ArgumentsType connection_tune: spec.Connection.Tune channels: Dict[int, Optional[AbstractChannel]] @@ -514,42 +452,36 @@ def set_close_reason( self, reply_code: int = REPLY_SUCCESS, reply_text: str = "normally closed", class_id: int = 0, method_id: int = 0, - ) -> None: - raise NotImplementedError + ) -> None: ... @abstractproperty - def is_opened(self) -> bool: - raise NotImplementedError + def is_opened(self) -> bool: ... + + @abstractmethod + def create_task(self, coro: CoroutineType) -> asyncio.Task: ... @abstractmethod - def __str__(self) -> str: - raise NotImplementedError + def __str__(self) -> str: ... @abstractmethod async def connect( self, client_properties: Optional[FieldTable] = None, - ) -> bool: - raise NotImplementedError + ) -> bool: ... @abstractproperty - def server_capabilities(self) -> ArgumentsType: - raise NotImplementedError + def server_capabilities(self) -> ArgumentsType: ... @abstractproperty - def basic_nack(self) -> bool: - raise NotImplementedError + def basic_nack(self) -> bool: ... @abstractproperty - def consumer_cancel_notify(self) -> bool: - raise NotImplementedError + def consumer_cancel_notify(self) -> bool: ... @abstractproperty - def exchange_exchange_bindings(self) -> bool: - raise NotImplementedError + def exchange_exchange_bindings(self) -> bool: ... @abstractproperty - def publisher_confirms(self) -> Optional[bool]: - raise NotImplementedError + def publisher_confirms(self) -> Optional[bool]: ... async def channel( self, @@ -558,12 +490,10 @@ async def channel( frame_buffer_size: int = FRAME_BUFFER_SIZE, timeout: TimeoutType = None, **kwargs: Any, - ) -> AbstractChannel: - raise NotImplementedError + ) -> AbstractChannel: ... @abstractmethod - async def __aenter__(self) -> "AbstractConnection": - raise NotImplementedError + async def __aenter__(self) -> "AbstractConnection": ... @abstractmethod async def __aexit__( @@ -571,24 +501,21 @@ async def __aexit__( exc_type: Optional[Type[BaseException]], exc_val: Optional[BaseException], exc_tb: Optional[TracebackType], - ) -> Optional[bool]: - raise NotImplementedError + ) -> Optional[bool]: ... @abstractmethod - async def ready(self) -> None: - raise NotImplementedError + async def ready(self) -> None: ... @abstractmethod async def update_secret( self, new_secret: str, *, reason: str = "", timeout: TimeoutType = None, - ) -> spec.Connection.UpdateSecretOk: - raise NotImplementedError + ) -> spec.Connection.UpdateSecretOk: ... __all__ = ( "AbstractBase", "AbstractChannel", "AbstractConnection", - "AbstractFutureStore", "ArgumentsType", "CallbackCoro", "ChannelFrame", + "ArgumentsType", "CallbackCoro", "ChannelFrame", "ChannelRType", "ConfirmationFrameType", "ConsumerCallback", "CoroutineType", "DeliveredMessage", "DrainResult", "ExceptionType", "FieldArray", "FieldTable", "FieldValue", "FrameReceived", "FrameType", diff --git a/aiormq/base.py b/aiormq/base.py index a36637a..4d77381 100644 --- a/aiormq/base.py +++ b/aiormq/base.py @@ -1,12 +1,8 @@ -import abc import asyncio -from contextlib import suppress -from functools import wraps -from typing import Any, Awaitable, Callable, Optional, Set, TypeVar -from weakref import WeakSet +from typing import Awaitable, Optional, Set, TypeVar, Any, Coroutine from .abc import ( - AbstractBase, AbstractFutureStore, CoroutineType, ExceptionType, + AbstractBase, CoroutineType, ExceptionType, TimeoutType, ) @@ -14,160 +10,53 @@ T = TypeVar("T") -class FutureStore(AbstractFutureStore): - __slots__ = ( - "futures", "loop", "parent", "__rejecting", - ) - - futures: Set[asyncio.Future] - weak_futures: WeakSet - loop: asyncio.AbstractEventLoop - - def __init__(self, loop: asyncio.AbstractEventLoop): - self.futures = set() - self.loop = loop - self.parent: Optional[FutureStore] = None - self.__rejecting: Optional[ExceptionType] = None - - def add(self, future: asyncio.Future) -> None: - self.futures.add(future) - future.add_done_callback(self.futures.discard) - if self.parent: - self.parent.add(future) - - def reject_all(self, exception: Optional[ExceptionType]) -> Awaitable[None]: - self.__rejecting = exception or RuntimeError("Has been rejected") - - tasks = [] - - while self.futures: - future: asyncio.Future = self.futures.pop() - - tasks.append(future) - - if future.done(): - continue - elif isinstance(future, asyncio.Task): - future.cancel() - elif isinstance(future, asyncio.Future): - future.set_exception(self.__rejecting) - - async def wait_rejected() -> None: - nonlocal tasks - try: - if not tasks: - return - await asyncio.gather(*tasks, return_exceptions=True) - finally: - self.__rejecting = None - return self.loop.create_task(wait_rejected()) +class Base(AbstractBase): + __slots__ = "loop", "closing", "__tasks" - async def __task_wrapper(self, coro: CoroutineType) -> Any: - if coro is None: - return - try: - return await coro - except asyncio.CancelledError as e: - if self.__rejecting is None: - raise - raise self.__rejecting from e + def __init__(self, *, loop: asyncio.AbstractEventLoop): + self.__tasks: Set[asyncio.Future] = set() + self.loop: asyncio.AbstractEventLoop = loop or asyncio.get_event_loop() + self.closing = self.create_future() def create_task(self, coro: CoroutineType) -> asyncio.Task: - task: asyncio.Task = self.loop.create_task(self.__task_wrapper(coro)) - self.add(task) + task = self.loop.create_task(coro) + task.add_done_callback(self.__tasks.discard) + self.__tasks.add(task) return task - def create_future(self, weak: bool = False) -> asyncio.Future: + def create_future(self) -> asyncio.Future: future = self.loop.create_future() - self.add(future) + future.add_done_callback(self.__tasks.discard) + self.__tasks.add(future) return future - def get_child(self) -> "FutureStore": - store = FutureStore(self.loop) - store.parent = self - return store - - -class Base(AbstractBase): - __slots__ = "loop", "__future_store", "closing" - - def __init__( - self, *, loop: asyncio.AbstractEventLoop, - parent: Optional[AbstractBase] = None, - ): - self.loop: asyncio.AbstractEventLoop = loop - - if parent: - self.__future_store = parent._future_store_child() - else: - self.__future_store = FutureStore(loop=self.loop) - - self.closing = self._create_closing_future() - - def _create_closing_future(self) -> asyncio.Future: - future = self.__future_store.create_future() - future.add_done_callback(lambda x: x.exception()) - return future - - def _cancel_tasks( - self, exc: Optional[ExceptionType] = None, - ) -> Awaitable[None]: - return self.__future_store.reject_all(exc) - - def _future_store_child(self) -> AbstractFutureStore: - return self.__future_store.get_child() - - def create_task(self, coro: CoroutineType) -> asyncio.Task: - return self.__future_store.create_task(coro) - - def create_future(self) -> asyncio.Future: - return self.__future_store.create_future() - - @abc.abstractmethod - async def _on_close( - self, exc: Optional[ExceptionType] = None, - ) -> None: # pragma: no cover - return - - async def __closer(self, exc: Optional[ExceptionType]) -> None: - if self.is_closed: # pragma: no cover - return - - with suppress(Exception): - await self._on_close(exc) - - with suppress(Exception): - await self._cancel_tasks(exc) - - async def close( - self, exc: Optional[ExceptionType] = asyncio.CancelledError, - timeout: TimeoutType = None, - ) -> None: - if self.is_closed: - return None - await asyncio.wait_for(self.__closer(exc), timeout=timeout) - - def __repr__(self) -> str: - cls_name = self.__class__.__name__ - return '<{0}: "{1}" at 0x{2:02x}>'.format( - cls_name, str(self), id(self), - ) - - @abc.abstractmethod - def __str__(self) -> str: # pragma: no cover - raise NotImplementedError - @property def is_closed(self) -> bool: return self.closing.done() + def close( + self, exc: Optional[BaseException] = asyncio.CancelledError, + timeout: TimeoutType = None, + ) -> Awaitable[Any]: + exc = exc or RuntimeError("Closed") + futures: Set[asyncio.Future] = set() -TaskFunctionType = Callable[..., T] + for future in self.__tasks: + futures.add(future) + if future.done(): + continue + elif isinstance(future, asyncio.Task): + future.cancel() + else: + future.set_exception(exc) -def task(func: TaskFunctionType) -> TaskFunctionType: - @wraps(func) - async def wrap(self: Base, *args: Any, **kwargs: Any) -> Any: - return await self.create_task(func(self, *args, **kwargs)) + async def closer(): + nonlocal futures + if not futures: + return + await asyncio.gather(*futures, return_exceptions=True) - return wrap + return self.loop.create_task( + asyncio.wait_for(closer(), timeout=timeout) + ) diff --git a/aiormq/channel.py b/aiormq/channel.py index 64b2b2c..d38b3ab 100644 --- a/aiormq/channel.py +++ b/aiormq/channel.py @@ -26,7 +26,7 @@ ConfirmationFrameType, ConsumerCallback, DeliveredMessage, ExceptionType, FrameType, GetResultType, ReturnCallback, RpcReturnType, TimeoutType, ) -from .base import Base, task +from .base import Base from .exceptions import ( AMQPChannelError, AMQPError, ChannelAccessRefused, ChannelClosed, ChannelInvalidStateError, ChannelLockedResource, ChannelNotFoundEntity, @@ -85,14 +85,11 @@ def __init__( on_return_raises: bool = True, ): - super().__init__(loop=connector.loop, parent=connector) - + super().__init__(loop=connector.loop) self.connection = connector - if ( - publisher_confirms and not connector.publisher_confirms - ): # pragma: no cover - raise ValueError("Server does't support publisher confirms") + if publisher_confirms and not connector.publisher_confirms: + raise ValueError("Server doesn't support publisher confirms") self.consumers: Dict[str, ConsumerCallback] = {} self.confirmations = OrderedDict() @@ -126,7 +123,6 @@ def __init__( self.__close_reply_text: str = "" self.__close_class_id: int = 0 self.__close_method_id: int = 0 - self.__close_event: asyncio.Event = asyncio.Event() def set_close_reason( self, reply_code: int = REPLY_SUCCESS, @@ -152,10 +148,9 @@ async def _get_frame(self) -> FrameType: def __str__(self) -> str: return str(self.number) - @task async def rpc(self, frame: Frame) -> RpcReturnType: - if self.__close_event.is_set(): + if self.is_closed: raise ChannelInvalidStateError("Channel closed by RPC timeout") lock = self.lock @@ -180,7 +175,7 @@ async def rpc(self, frame: Frame) -> RpcReturnType: raise InvalidFrameError(frame) return result - except (asyncio.CancelledError, asyncio.TimeoutError): + except (asyncio.CancelledError, asyncio.TimeoutError) as e: if self.is_closed: raise @@ -189,8 +184,12 @@ async def rpc(self, frame: Frame) -> RpcReturnType: self, frame, ) - self.__close_event.set() - await self.write_queue.put( + if self.closing.done(): + raise + + self.closing.set_exception(e) + + self.write_queue.put_nowait( ChannelFrame.marshall( channel_number=self.number, frames=[ @@ -407,12 +406,15 @@ async def _on_close_frame(self, frame: spec.Channel.Close) -> None: ), ) self.connection.channels.pop(self.number, None) - self.__close_event.set() + self.closing.set_exception(exception_by_code(frame)) raise exc - async def _on_close_ok_frame(self, _: spec.Channel.CloseOk) -> None: + async def _on_close_ok_frame(self, frame: spec.Channel.CloseOk) -> None: self.connection.channels.pop(self.number, None) - self.__close_event.set() + + if not self.closing.done(): + self.closing.set_result(frame) + await self.rpc_frames.put(frame) raise ChannelClosed(None, None) async def _reader(self) -> None: @@ -431,8 +433,6 @@ async def _reader(self) -> None: spec.Basic.Nack: (False, self._on_confirm_frame), } - last_exception: Optional[BaseException] = None - try: while True: frame = await self._get_frame() @@ -443,34 +443,33 @@ async def _reader(self) -> None: if should_add_to_rpc: await self.rpc_frames.put(frame) - except asyncio.CancelledError as e: - self.__close_event.set() - last_exception = e + except ChannelClosed: return - except Exception as e: - last_exception = e + except BaseException as e: + if not self.closing.done(): + self.closing.set_exception(e) raise - finally: - await self.close( - last_exception, timeout=self.CHANNEL_CLOSE_TIMEOUT, - ) - async def _on_close(self, exc: Optional[ExceptionType] = None) -> None: - if not self.connection.is_opened or self.__close_event.is_set(): - return + def close( + self, exc: Optional[BaseException] = None, timeout: TimeoutType = None + ) -> Awaitable[spec.Channel.CloseOk]: + super_close = super().close - await asyncio.wait_for( - self.rpc( + async def closer() -> spec.Channel.CloseOk: + result = await self.rpc( spec.Channel.Close( reply_code=self.__close_reply_code, class_id=self.__close_class_id, method_id=self.__close_method_id, ), - ), - timeout=self.connection.connection_tune.heartbeat or None, - ) + ) - await self.__close_event.wait() + await super_close(exc) + return result + + return self.loop.create_task( + asyncio.wait_for(closer(), timeout=timeout) + ) async def basic_get( self, queue: str = "", no_ack: bool = False, diff --git a/aiormq/connection.py b/aiormq/connection.py index b7f4dba..94513e7 100644 --- a/aiormq/connection.py +++ b/aiormq/connection.py @@ -25,10 +25,10 @@ from .abc import ( AbstractChannel, AbstractConnection, ArgumentsType, ChannelFrame, - ExceptionType, SSLCerts, URLorStr, + SSLCerts, URLorStr, ) from .auth import AuthMechanism -from .base import Base, task +from .base import Base from .channel import Channel from .exceptions import ( AMQPConnectionError, AMQPError, AuthenticationError, ConnectionChannelError, @@ -232,8 +232,7 @@ class Connection(Base, AbstractConnection): READER_CLOSE_TIMEOUT = 2 - _reader_task: asyncio.Task - _writer_task: asyncio.Task + loop: asyncio.AbstractEventLoop write_queue: asyncio.Queue server_properties: ArgumentsType connection_tune: spec.Connection.Tune @@ -250,8 +249,7 @@ def __init__( loop: Optional[asyncio.AbstractEventLoop] = None, context: Optional[ssl.SSLContext] = None, ): - - super().__init__(loop=loop or asyncio.get_event_loop(), parent=None) + super().__init__(loop=loop) self.url = URL(url) if self.url.is_absolute() and not self.url.port: @@ -272,12 +270,8 @@ def __init__( verify=self.url.query.get("no_verify_ssl", "0") == "0", ) - self.started = False self.channels = {} - self.write_queue = asyncio.Queue( - maxsize=self.FRAME_BUFFER_SIZE, - ) - + self.write_queue = asyncio.Queue(maxsize=self.FRAME_BUFFER_SIZE) self.last_channel = 1 self.timeout = parse_int(self.url.query.get("timeout", "60")) @@ -286,6 +280,7 @@ def __init__( ) self.last_channel_lock = asyncio.Lock() self.connected = asyncio.Event() + self.closing = self.loop.create_future() self.connection_name = self.url.query.get("name") self.__close_reply_code: int = REPLY_SUCCESS @@ -314,21 +309,6 @@ def set_close_reason( self.__close_class_id = class_id self.__close_method_id = method_id - @property - def is_opened(self) -> bool: - is_reader_running = ( - hasattr(self, "_reader_task") and not self._reader_task.done() - ) - is_writer_running = ( - hasattr(self, "_writer_task") and not self._writer_task.done() - ) - - return ( - is_reader_running and - is_writer_running and - not self.is_closed - ) - def __str__(self) -> str: if self.url.password is not None: return str(self.url.with_password("******")) @@ -410,13 +390,41 @@ async def _rpc( raise ConnectionClosed(frame.reply_code, frame.reply_text) return frame - @task + def close( + self, exc: Optional[BaseException] = asyncio.CancelledError, + timeout: TimeoutType = None, + ) -> Awaitable[spec.Connection.CloseOk]: + super_closer = super().close + + async def closer() -> None: + nonlocal super_closer + + await self.write_queue.put( + ChannelFrame.marshall( + channel_number=0, + frames=[ + spec.Connection.Close( + reply_code=self.__close_reply_code, + reply_text=self.__close_reply_text, + class_id=self.__close_class_id, + method_id=self.__close_method_id, + ) + ] + ), + ) + + try: + return await self.closing + finally: + await super_closer(exc) + + return self.loop.create_task( + asyncio.wait_for(closer(), timeout=timeout) + ) + async def connect( self, client_properties: Optional[FieldTable] = None, ) -> bool: - if self.is_opened: - raise RuntimeError("Connection already opened") - ssl_context = self.ssl_context if ssl_context is None and self.url.scheme == "amqps": @@ -490,45 +498,40 @@ async def connect( if not isinstance(frame, spec.Connection.OpenOk): raise AMQPInternalError("Connection.OpenOk", frame) - except BaseException as e: + except BaseException: await self.__close_writer(writer) - await self.close(e) + await self.close() raise - # noinspection PyAsyncCall - self._reader_task = self.create_task(self.__reader(frame_receiver)) - self._reader_task.add_done_callback(self._on_reader_done) + reader_task = self.create_task(self.__reader(frame_receiver)) + reader_task.add_done_callback(self._on_reader_done) - # noinspection PyAsyncCall - self._writer_task = self.create_task(self.__writer(writer)) + writer_task = self.create_task(self.__writer(writer)) + writer_task.add_done_callback(self._on_writer_done) self.connection_tune = connection_tune self.server_properties = server_properties return True - def _on_reader_done(self, task: asyncio.Task) -> None: - log.debug("Reader exited for %r", self) + @property + def is_opened(self) -> bool: + return self.connected.is_set() - if not task.cancelled() and task.exception() is not None: - log.debug("Cancelling cause reader exited abnormally") - self.set_close_reason( - reply_code=500, reply_text="reader unexpected closed", - ) + def _on_writer_done(self, _: asyncio.Task) -> None: + log.debug("Writer exited for %r", self) + if not self.is_closed: + self.close() - async def close_writer_task() -> None: - if not self._writer_task.done(): - self._writer_task.cancel() - await asyncio.gather(self._writer_task, return_exceptions=True) - try: - exc = task.exception() - except asyncio.CancelledError as e: - exc = e - await self.close(exc) + def _on_reader_done(self, _: asyncio.Task) -> None: + log.debug("Reader exited for %r", self) + if not self.is_closed: + self.close() - self.loop.create_task(close_writer_task()) + async def __handle_close_ok(self, frame: spec.Connection.CloseOk) -> None: + self.connected.clear() - async def __handle_close_ok(self, _: spec.Connection.CloseOk) -> None: - return + if not self.closing.done(): + self.closing.set_result(frame) async def __handle_heartbeat(self, _: Heartbeat) -> None: return @@ -555,6 +558,7 @@ async def __handle_close(self, frame: spec.Connection.Close) -> None: ), ) + self.connected.clear() exception = exception_by_code(frame) if ( @@ -567,7 +571,7 @@ async def __handle_close(self, frame: spec.Connection.Close) -> None: async def __handle_channel_close_ok( self, _: spec.Channel.CloseOk, ) -> None: - self.channels.pop(0, None) + pass async def __handle_channel_update_secret_ok( self, frame: spec.Connection.UpdateSecretOk, @@ -598,7 +602,8 @@ async def __reader(self, frame_receiver: FrameReceiver) -> None: # Not very optimal, but avoid creating a task for each frame sending # noinspection PyAsyncCall - self.create_task(self.__heartbeat()) + heartbeat_task = self.create_task(self.__heartbeat()) + heartbeat_task.add_done_callback(lambda _: self.close()) channel_frame_handlers: Mapping[Any, Callable[[Any], Awaitable[None]]] channel_frame_handlers = { @@ -649,7 +654,6 @@ async def __reader(self, frame_receiver: FrameReceiver) -> None: "Server connection %r was stuck. No frames were received " "in %d seconds.", self, self.__heartbeat_grace_timeout, ) - self._writer_task.cancel() raise @property @@ -665,7 +669,6 @@ async def __heartbeat(self) -> None: while not self.closing.done(): if self.is_connection_was_stuck: - self._reader_task.cancel() return await asyncio.sleep(heartbeat_timeout) @@ -676,7 +679,6 @@ async def __heartbeat(self) -> None: timeout=self.__heartbeat_grace_timeout, ) except asyncio.TimeoutError: - self._reader_task.cancel() return async def __writer(self, writer: asyncio.StreamWriter) -> None: @@ -708,6 +710,9 @@ async def __writer(self, writer: asyncio.StreamWriter) -> None: if not self.__check_writer(writer) or self.is_connection_was_stuck: raise + if self.closing.done(): + raise + frame = spec.Connection.Close( reply_code=self.__close_reply_code, reply_text=self.__close_reply_text, @@ -756,20 +761,6 @@ def __check_writer(writer: asyncio.StreamWriter) -> bool: return writer.can_write_eof() - async def _on_close( - self, - ex: Optional[ExceptionType] = ConnectionClosed(0, "normal closed"), - ) -> None: - log.debug("Closing connection %r cause: %r", self, ex) - if not self._reader_task.done(): - self._reader_task.cancel() - if not self._writer_task.done(): - self._writer_task.cancel() - - await asyncio.gather( - self._reader_task, self._writer_task, return_exceptions=True, - ) - @property def server_capabilities(self) -> ArgumentsType: return self.server_properties["capabilities"] # type: ignore @@ -793,7 +784,6 @@ def publisher_confirms(self) -> Optional[bool]: return None return bool(publisher_confirms) - @task async def channel( self, channel_number: Optional[int] = None, @@ -806,7 +796,7 @@ async def channel( await self.connected.wait() if self.is_closed: - raise RuntimeError("%r closed" % self) + raise RuntimeError(f"{self!r} closed") if not self.publisher_confirms and publisher_confirms: raise ValueError("Server doesn't support publisher_confirms") @@ -882,13 +872,13 @@ async def __aenter__(self) -> AbstractConnection: await self.connect() return self - async def __aexit__( + def __aexit__( self, exc_type: Optional[Type[BaseException]], exc_val: Optional[BaseException], exc_tb: Optional[TracebackType], - ) -> None: - await self.close(exc_val) + ) -> Awaitable[Any]: + return self.close() async def connect( diff --git a/tests/test_future_store.py b/tests/test_future_store.py deleted file mode 100644 index 6a45d05..0000000 --- a/tests/test_future_store.py +++ /dev/null @@ -1,88 +0,0 @@ -import asyncio - -import pytest - -from aiormq.base import FutureStore - - -@pytest.fixture -async def root_store(event_loop: asyncio.AbstractEventLoop): - store = FutureStore(loop=event_loop) - try: - yield store - finally: - await store.reject_all(Exception("Cancelling")) - - -@pytest.fixture -async def child_store(event_loop, root_store: FutureStore): - store = root_store.get_child() - try: - yield store - finally: - await store.reject_all(Exception("Cancelling")) - - -async def test_reject_all( - event_loop, root_store: FutureStore, child_store: FutureStore, -): - - future1 = root_store.create_future() - future2 = child_store.create_future() - - assert root_store.futures - assert child_store.futures - - await root_store.reject_all(RuntimeError) - await asyncio.sleep(0.1) - - assert isinstance(future1.exception(), RuntimeError) - assert isinstance(future2.exception(), RuntimeError) - assert not root_store.futures - assert not child_store.futures - - -async def test_result( - event_loop, root_store: FutureStore, child_store: FutureStore, -): - async def result(): - await asyncio.sleep(0.1) - return "result" - - assert await child_store.create_task(result()) == "result" - - -async def test_siblings( - event_loop, root_store: FutureStore, child_store: FutureStore, -): - async def coro(store): - await asyncio.sleep(0.1) - await store.reject_all(RuntimeError) - - task1 = child_store.create_task(coro(child_store)) - assert root_store.futures - assert child_store.futures - - with pytest.raises(RuntimeError): - await task1 - - await asyncio.sleep(0.1) - - assert not root_store.futures - assert not child_store.futures - - child = child_store.get_child().get_child().get_child() - task = child.create_task(coro(child)) - - assert root_store.futures - assert child_store.futures - assert child.futures - - with pytest.raises(RuntimeError): - await task - - await asyncio.sleep(0.1) - - assert not root_store.futures - assert not child_store.futures - assert not child.futures