From 84f1af2d840aa14cde989641eaa265fe34bdf4ae Mon Sep 17 00:00:00 2001 From: Ondrej Novak Date: Thu, 31 Oct 2024 15:35:35 +0100 Subject: [PATCH] Add abstract `TransportFactory` --- aiormq/__init__.py | 3 +- aiormq/connection.py | 60 ++++++++++++++++++++++++++++++++++------ tests/test_connection.py | 40 +++++++++++++++++++++++++-- 3 files changed, 92 insertions(+), 11 deletions(-) diff --git a/aiormq/__init__.py b/aiormq/__init__.py index b9f9b12..605e835 100644 --- a/aiormq/__init__.py +++ b/aiormq/__init__.py @@ -2,7 +2,7 @@ from . import abc from .channel import Channel -from .connection import Connection, connect +from .connection import Connection, connect, TransportFactory from .exceptions import ( AMQPChannelError, AMQPConnectionError, AMQPError, AMQPException, AuthenticationError, ChannelAccessRefused, ChannelClosed, @@ -34,6 +34,7 @@ "ConnectionChannelError", "ConnectionClosed", "ConnectionCommandInvalid", + "TransportFactory", "ConnectionFrameError", "ConnectionInternalError", "ConnectionNotAllowed", diff --git a/aiormq/connection.py b/aiormq/connection.py index be2e927..fc97ba7 100644 --- a/aiormq/connection.py +++ b/aiormq/connection.py @@ -3,6 +3,7 @@ import platform import ssl import sys +from abc import abstractmethod, ABC from base64 import b64decode from collections.abc import AsyncIterable from contextlib import suppress @@ -246,6 +247,45 @@ async def __anext__(self) -> ChannelFrame: return frame +class TransportFactory(ABC): + """ + Abstract factory class allowing to open connections with generic + transports. + """ + + @staticmethod + @abstractmethod + def is_ssl_url(url: URL) -> bool: + pass + + @abstractmethod + async def create( + self, + url: URL, ssl: Optional[ssl.SSLContext], + **kwargs: dict[str, Any] + ) -> tuple[asyncio.StreamReader, asyncio.StreamWriter]: + """Create a transport connection to the AMQP server.""" + pass + + +class _DefaultTransportFactory(TransportFactory): + @staticmethod + def is_ssl_url(url: URL) -> bool: + return url.scheme == "amqps" + + async def create( + self, + url: URL, ssl: Optional[ssl.SSLContext], + **kwargs: Any + ) -> tuple[asyncio.StreamReader, asyncio.StreamWriter]: + try: + return await asyncio.open_connection( + host=url.host, port=url.port, ssl=ssl, **kwargs, + ) + except OSError as e: + raise AMQPConnectionError(*e.args) from e + + class Connection(Base, AbstractConnection): FRAME_BUFFER_SIZE = 10 # Interval between sending heartbeats based on the heartbeat(timeout) @@ -274,6 +314,7 @@ def __init__( *, loop: Optional[asyncio.AbstractEventLoop] = None, context: Optional[ssl.SSLContext] = None, + transport_factory: Optional[TransportFactory] = None, **create_connection_kwargs: Any, ): @@ -315,6 +356,9 @@ def __init__( self.last_channel_lock = asyncio.Lock() self.connected = asyncio.Event() self.connection_name = self.url.query.get("name") + self._transport_factory = ( + transport_factory or _DefaultTransportFactory() + ) self.__close_reply_code: int = REPLY_SUCCESS self.__close_reply_text: str = "normally closed" @@ -445,8 +489,8 @@ async def connect( raise RuntimeError("Connection already opened") ssl_context = self.ssl_context - - if ssl_context is None and self.url.scheme == "amqps": + is_ssl_url = self._transport_factory.is_ssl_url(self.url) + if ssl_context is None and is_ssl_url: ssl_context = await self.loop.run_in_executor( None, self._get_ssl_context, ) @@ -454,15 +498,15 @@ async def connect( log.debug("Connecting to: %s", self) try: - reader, writer = await asyncio.open_connection( - self.url.host, self.url.port, ssl=ssl_context, + reader, writer = await self._transport_factory.create( + self.url, ssl=self.ssl_context, **self.__create_connection_kwargs, ) + except Exception as e: + log.error("error when creating transport: %r", e) + raise e - frame_receiver = FrameReceiver(reader) - except OSError as e: - raise AMQPConnectionError(*e.args) from e - + frame_receiver = FrameReceiver(reader) frame: Optional[FrameTypes] try: diff --git a/tests/test_connection.py b/tests/test_connection.py index ad82610..dece4a4 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -4,7 +4,7 @@ import ssl import uuid from binascii import hexlify -from typing import Optional +from typing import Any, Optional import aiomisc import pytest @@ -14,7 +14,12 @@ import aiormq from aiormq.abc import DeliveredMessage from aiormq.auth import AuthBase, ExternalAuth, PlainAuth -from aiormq.connection import parse_int, parse_timeout, parse_bool +from aiormq.connection import ( + TransportFactory, + parse_int, + parse_timeout, + parse_bool +) from .conftest import AMQP_URL, cert_path, skip_when_quick_test @@ -119,6 +124,37 @@ async def test_open(amqp_connection): await amqp_connection.close() +class _TcpTransportFactory(TransportFactory): + @staticmethod + def is_ssl_url(url: URL) -> bool: + return url.scheme == "amqps" + + async def create( + self, + url: URL, ssl: Optional[ssl.SSLContext], + **kwargs: dict[str, Any] + ) -> tuple[asyncio.StreamReader, asyncio.StreamWriter]: + loop = asyncio.get_event_loop() + reader = asyncio.StreamReader(loop=loop) + protocol = asyncio.StreamReaderProtocol(reader, loop=loop) + transport, _ = await loop.create_connection( + lambda: protocol, url.host, url.port, ssl=ssl, + ) + writer = asyncio.StreamWriter(transport, protocol, reader, loop) + return reader, writer + + +async def test_open_with_transport_factory(amqp_url): + amqp_connection = await aiormq.connect( + amqp_url, + transport_factory=_TcpTransportFactory(), + ) + + channel = await amqp_connection.channel() + await channel.close() + await amqp_connection.close() + + async def test_channel_close(amqp_connection): channel = await amqp_connection.channel()