Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add abstract TransportFactory #205

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion aiormq/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from . import abc
from .channel import Channel
from .connection import Connection, connect
from .connection import Connection, connect, TransportFactory
from .exceptions import (
AMQPChannelError, AMQPConnectionError, AMQPError, AMQPException,
AuthenticationError, ChannelAccessRefused, ChannelClosed,
Expand Down Expand Up @@ -34,6 +34,7 @@
"ConnectionChannelError",
"ConnectionClosed",
"ConnectionCommandInvalid",
"TransportFactory",
"ConnectionFrameError",
"ConnectionInternalError",
"ConnectionNotAllowed",
Expand Down
60 changes: 52 additions & 8 deletions aiormq/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import platform
import ssl
import sys
from abc import abstractmethod, ABC
from base64 import b64decode
from collections.abc import AsyncIterable
from contextlib import suppress
Expand Down Expand Up @@ -246,6 +247,45 @@ async def __anext__(self) -> ChannelFrame:
return frame


class TransportFactory(ABC):
"""
Abstract factory class allowing to open connections with generic
transports.
"""

@staticmethod
@abstractmethod
def is_ssl_url(url: URL) -> bool:
pass

@abstractmethod
async def create(
self,
url: URL, ssl: Optional[ssl.SSLContext],
**kwargs: dict[str, Any]
) -> tuple[asyncio.StreamReader, asyncio.StreamWriter]:
"""Create a transport connection to the AMQP server."""
pass


class _DefaultTransportFactory(TransportFactory):
@staticmethod
def is_ssl_url(url: URL) -> bool:
return url.scheme == "amqps"

async def create(
self,
url: URL, ssl: Optional[ssl.SSLContext],
**kwargs: Any
) -> tuple[asyncio.StreamReader, asyncio.StreamWriter]:
try:
return await asyncio.open_connection(
host=url.host, port=url.port, ssl=ssl, **kwargs,
)
except OSError as e:
raise AMQPConnectionError(*e.args) from e


class Connection(Base, AbstractConnection):
FRAME_BUFFER_SIZE = 10
# Interval between sending heartbeats based on the heartbeat(timeout)
Expand Down Expand Up @@ -274,6 +314,7 @@ def __init__(
*,
loop: Optional[asyncio.AbstractEventLoop] = None,
context: Optional[ssl.SSLContext] = None,
transport_factory: Optional[TransportFactory] = None,
**create_connection_kwargs: Any,
):

Expand Down Expand Up @@ -315,6 +356,9 @@ def __init__(
self.last_channel_lock = asyncio.Lock()
self.connected = asyncio.Event()
self.connection_name = self.url.query.get("name")
self._transport_factory = (
transport_factory or _DefaultTransportFactory()
)

self.__close_reply_code: int = REPLY_SUCCESS
self.__close_reply_text: str = "normally closed"
Expand Down Expand Up @@ -445,24 +489,24 @@ async def connect(
raise RuntimeError("Connection already opened")

ssl_context = self.ssl_context

if ssl_context is None and self.url.scheme == "amqps":
is_ssl_url = self._transport_factory.is_ssl_url(self.url)
if ssl_context is None and is_ssl_url:
ssl_context = await self.loop.run_in_executor(
None, self._get_ssl_context,
)
self.ssl_context = ssl_context

log.debug("Connecting to: %s", self)
try:
reader, writer = await asyncio.open_connection(
self.url.host, self.url.port, ssl=ssl_context,
reader, writer = await self._transport_factory.create(
self.url, ssl=self.ssl_context,
**self.__create_connection_kwargs,
)
except Exception as e:
log.error("error when creating transport: %r", e)
raise e

frame_receiver = FrameReceiver(reader)
except OSError as e:
raise AMQPConnectionError(*e.args) from e

frame_receiver = FrameReceiver(reader)
frame: Optional[FrameTypes]

try:
Expand Down
40 changes: 38 additions & 2 deletions tests/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import ssl
import uuid
from binascii import hexlify
from typing import Optional
from typing import Any, Optional

import aiomisc
import pytest
Expand All @@ -14,7 +14,12 @@
import aiormq
from aiormq.abc import DeliveredMessage
from aiormq.auth import AuthBase, ExternalAuth, PlainAuth
from aiormq.connection import parse_int, parse_timeout, parse_bool
from aiormq.connection import (
TransportFactory,
parse_int,
parse_timeout,
parse_bool
)

from .conftest import AMQP_URL, cert_path, skip_when_quick_test

Expand Down Expand Up @@ -119,6 +124,37 @@ async def test_open(amqp_connection):
await amqp_connection.close()


class _TcpTransportFactory(TransportFactory):
@staticmethod
def is_ssl_url(url: URL) -> bool:
return url.scheme == "amqps"

async def create(
self,
url: URL, ssl: Optional[ssl.SSLContext],
**kwargs: dict[str, Any]
) -> tuple[asyncio.StreamReader, asyncio.StreamWriter]:
loop = asyncio.get_event_loop()
reader = asyncio.StreamReader(loop=loop)
protocol = asyncio.StreamReaderProtocol(reader, loop=loop)
transport, _ = await loop.create_connection(
lambda: protocol, url.host, url.port, ssl=ssl,
)
writer = asyncio.StreamWriter(transport, protocol, reader, loop)
return reader, writer


async def test_open_with_transport_factory(amqp_url):
amqp_connection = await aiormq.connect(
amqp_url,
transport_factory=_TcpTransportFactory(),
)

channel = await amqp_connection.channel()
await channel.close()
await amqp_connection.close()


async def test_channel_close(amqp_connection):
channel = await amqp_connection.channel()

Expand Down