Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow test for python 3.8 #69

Merged
merged 3 commits into from
Nov 20, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 31 additions & 1 deletion .drone.yml
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,21 @@ steps:
- name: cache
path: /drone/src/.tox

- name: python 3.8
image: snakepacker/python:all
pull: always
commands:
- wait-for-port rabbitmq:5672 rabbitmq:5671
- tox
environment:
AMQP_URL: amqp://guest:guest@rabbitmq
TOXENV: py38
COVERALLS_REPO_TOKEN:
from_secret: COVERALLS_TOKEN
volumes:
- name: cache
path: /drone/src/.tox

- name: python 3.7
image: snakepacker/python:all
pull: always
Expand Down Expand Up @@ -89,6 +104,21 @@ steps:
- name: cache
path: /drone/src/.tox

- name: python 3.8 uvloop
image: snakepacker/python:all
pull: always
commands:
- wait-for-port rabbitmq:5672 rabbitmq:5671
- tox
environment:
AMQP_URL: amqp://guest:guest@rabbitmq
TOXENV: py38-uvloop
COVERALLS_REPO_TOKEN:
from_secret: COVERALLS_TOKEN
volumes:
- name: cache
path: /drone/src/.tox

- name: python 3.7 uvloop
image: snakepacker/python:all
pull: always
Expand Down Expand Up @@ -161,6 +191,6 @@ services:

---
kind: signature
hmac: 65654309225351bc649874b6b007fb576e6d0c10cb5e5b79584305ad931b0268
hmac: db78b836fb377a913c4eed3e22177b502268738102c19982dd2c3302510cd7f1

...
4 changes: 2 additions & 2 deletions aiormq/channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from pamqp import specification as spec, ContentHeader
from pamqp.body import ContentBody

from aiormq.tools import LazyCoroutine
from aiormq.tools import LazyCoroutine, awaitable
from . import exceptions as exc
from .base import Base, task
from .types import (
Expand Down Expand Up @@ -357,7 +357,7 @@ async def basic_consume(
if consumer_tag in self.consumers:
raise exc.DuplicateConsumerTag(self.number)

self.consumers[consumer_tag] = consumer_callback
self.consumers[consumer_tag] = awaitable(consumer_callback)

# noinspection PyTypeChecker
return await self.rpc(
Expand Down
66 changes: 30 additions & 36 deletions aiormq/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,15 @@
from .auth import AuthMechanism
from .base import Base, task
from .channel import Channel
from .tools import censor_url, shield
from .tools import censor_url
from .types import ArgumentsType, SSLCerts, URLorStr
from .version import __version__

log = logging.getLogger(__name__)


CHANNEL_CLOSE_RESPONSES = (spec.Channel.Close, spec.Channel.CloseOk)

try:
from yarl import DEFAULT_PORTS

Expand Down Expand Up @@ -59,7 +61,6 @@ def parse_connection_name(conn_name: str):


class Connection(Base):
CLOSE_TIMEOUT = 1.0
FRAME_BUFFER = 10
# Interval between sending heartbeats based on the heartbeat(timeout)
HEARTBEAT_INTERVAL_MULTIPLIER = 0.5
Expand Down Expand Up @@ -176,17 +177,10 @@ def _client_properties(self, **kwargs):
"publisher_confirms": True,
},
"information": "See https://github.com/mosquito/aiormq/",
"client_properties": {},
}

properties["client_properties"].update(
parse_connection_name(self.connection_name)
)

properties["client_properties"].update(
kwargs.get("client_properties", {})
)

properties.update(parse_connection_name(self.connection_name))
properties.update(kwargs.get("client_properties", {}))
return properties

@staticmethod
Expand All @@ -199,7 +193,6 @@ def _credentials_class(start_frame: spec.Connection.Start):
start_frame.mechanisms, [m.name for m in AuthMechanism]
)

@shield
async def __rpc(self, request: spec.Frame, wait_response=True):
self.writer.write(pamqp.frame.marshal(request, 0))

Expand All @@ -212,9 +205,13 @@ async def __rpc(self, request: spec.Frame, wait_response=True):
raise spec.AMQPInternalError(frame, dict(frame))
elif isinstance(frame, spec.Connection.Close):
if frame.reply_code == 403:
raise exc.ProbableAuthenticationError(frame.reply_text)
err = exc.ProbableAuthenticationError(frame.reply_text)
else:
err = exc.ConnectionClosed(frame.reply_code, frame.reply_text)

await self.close(err)

raise exc.ConnectionClosed(frame.reply_code, frame.reply_text)
raise err
return frame

@task
Expand Down Expand Up @@ -334,6 +331,9 @@ async def __receive_frame(self) -> typing.Tuple[int, int, spec.Frame]:
if frame_header == b"\0x00":
raise spec.AMQPFrameError(await self.reader.read())

if self.reader is None:
raise ConnectionError

frame_header += await self.reader.readexactly(6)

if not self.started and frame_header.startswith(b"AMQP"):
Expand Down Expand Up @@ -399,12 +399,7 @@ async def __reader(self):

ch = self.channels[channel]

channel_close_responses = (
spec.Channel.Close,
spec.Channel.CloseOk,
)

if isinstance(frame, channel_close_responses):
if isinstance(frame, CHANNEL_CLOSE_RESPONSES):
self.channels[channel] = None

await ch.frames.put((weight, frame))
Expand All @@ -419,13 +414,13 @@ async def __reader(self):

@staticmethod
async def __close_writer(writer: asyncio.StreamWriter):
writer.close()

wait_closed = getattr(writer, "wait_closed", None)
if not wait_closed:
if writer is None:
return

return await wait_closed()
writer.close()

if hasattr(writer, "wait_closed"):
await writer.wait_closed()

async def _on_close(self, ex=exc.ConnectionClosed(0, "normal closed")):
frame = (
Expand All @@ -434,24 +429,20 @@ async def _on_close(self, ex=exc.ConnectionClosed(0, "normal closed")):
else spec.Connection.Close()
)

await asyncio.wait(
{
asyncio.gather(
self.__rpc(frame, wait_response=False),
return_exceptions=True,
),
self._reader_task,
},
timeout=Connection.CLOSE_TIMEOUT,
return_when=asyncio.ALL_COMPLETED,
await asyncio.gather(
self.__rpc(frame, wait_response=False), return_exceptions=True
)

writer = self.writer
self.reader = None
self.writer = None
self._reader_task = None

await self.__close_writer(writer)
await asyncio.gather(
self.__close_writer(writer), return_exceptions=True
)

await asyncio.gather(self._reader_task, return_exceptions=True)

@property
def server_capabilities(self) -> ArgumentsType:
Expand Down Expand Up @@ -491,6 +482,9 @@ async def channel(

if channel_number is None:
async with self.last_channel_lock:
if self.channels:
self.last_channel = max(self.channels.keys())

while self.last_channel in self.channels.keys():
self.last_channel += 1

Expand Down
24 changes: 20 additions & 4 deletions aiormq/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,28 @@ def censor_url(url: URL):


def shield(func):
async def awaiter(future):
return await future

@wraps(func)
def wrap(*args, **kwargs):
return wraps(func)(awaiter)(asyncio.shield(func(*args, **kwargs)))
return asyncio.shield(awaitable(func)(*args, **kwargs))

return wrap


def awaitable(func):
# Avoid python 3.8+ warning
if asyncio.iscoroutinefunction(func):
return func

@wraps(func)
async def wrap(*args, **kwargs):
result = func(*args, **kwargs)

if hasattr(result, "__await__"):
return await result
if asyncio.iscoroutine(result) or asyncio.isfuture(result):
return await result

return result

return wrap

Expand Down
1 change: 1 addition & 0 deletions pytest.ini
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
[pytest]
markers =
no_catch_loop_exceptions: no catch unhandled exceptions from event loop
allow_get_event_loop: allow to call asyncio.get_event_loop()
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
"Programming Language :: Python :: 3.5",
"Programming Language :: Python :: 3.6",
"Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: Implementation :: PyPy",
"Programming Language :: Python :: Implementation :: CPython",
],
Expand Down
4 changes: 2 additions & 2 deletions tests/certs/Dockerfile
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
FROM rabbitmq:3-management-alpine
FROM rabbitmq:3.8-management-alpine

RUN mkdir -p /certs
RUN mkdir -p /certs/

COPY tests/certs/ca.pem /certs/
COPY tests/certs/server.key /certs/
Expand Down
24 changes: 19 additions & 5 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import asyncio
import pprint

import gc
import logging
import os
Expand Down Expand Up @@ -42,15 +44,20 @@ def event_loop(request):
def getter_mock():
raise RuntimeError("asyncio.get_event_loop() call forbidden")

asyncio.get_event_loop = getter_mock
if not request.node.get_closest_marker("allow_get_event_loop"):
asyncio.get_event_loop = getter_mock

nocatch_marker = request.node.get_closest_marker(
"no_catch_loop_exceptions"
)

def on_exception(loop, err):
logging.exception("%s", pprint.pformat(err))
exceptions.append(err)

exceptions = list()
if not nocatch_marker:
loop.set_exception_handler(lambda l, c: exceptions.append(c))
loop.set_exception_handler(on_exception)

try:
yield loop
Expand All @@ -59,7 +66,9 @@ def getter_mock():
raise RuntimeError(exceptions)

finally:
asyncio.get_event_loop = original
if not request.node.get_closest_marker("allow_get_event_loop"):
asyncio.get_event_loop = original

asyncio.set_event_loop_policy(None)
del policy

Expand Down Expand Up @@ -100,9 +109,14 @@ def cert_path(*args):


@pytest.fixture(params=amqp_url_list, ids=amqp_url_ids)
async def amqp_url(request):
return request.param


@pytest.fixture
@async_generator
async def amqp_connection(request, event_loop):
connection = Connection(request.param, loop=event_loop)
async def amqp_connection(amqp_url, event_loop):
connection = Connection(amqp_url, loop=event_loop)

await connection.connect()

Expand Down
2 changes: 1 addition & 1 deletion tests/test_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ async def test_ack_nack_reject(amqp_channel: aiormq.Channel, event_loop):
await channel.basic_qos(prefetch_count=1)

declare_ok = await channel.queue_declare(auto_delete=True)
queue = asyncio.Queue(loop=event_loop)
queue = asyncio.Queue()

await channel.basic_consume(declare_ok.queue, queue.put, no_ack=False)

Expand Down
17 changes: 10 additions & 7 deletions tests/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,16 +208,19 @@ async def test_non_publisher_confirms(amqp_connection):

@skip_when_quick_test
@pytest.mark.no_catch_loop_exceptions
async def test_no_free_channels(amqp_connection):
await asyncio.wait(
[
amqp_connection.channel(n + 1)
for n in range(amqp_connection.connection_tune.channel_max)
]
async def test_no_free_channels(amqp_connection: aiormq.Connection):
await asyncio.wait_for(
asyncio.wait(
[
amqp_connection.channel(n + 1)
for n in range(amqp_connection.connection_tune.channel_max)
]
),
timeout=60,
)

with pytest.raises(aiormq.exceptions.ConnectionNotAllowed):
await amqp_connection.channel()
await asyncio.wait_for(amqp_connection.channel(), timeout=5)


async def test_huge_message(amqp_connection: aiormq.Connection):
Expand Down
Loading