Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix aiohttp wait for closed ssl connections #153

64 changes: 64 additions & 0 deletions gql/transport/aiohttp.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import asyncio
import functools
import io
import json
import logging
Expand Down Expand Up @@ -44,6 +46,7 @@ def __init__(
auth: Optional[BasicAuth] = None,
ssl: Union[SSLContext, bool, Fingerprint] = False,
timeout: Optional[int] = None,
ssl_close_timeout: Optional[Union[int, float]] = 10,
client_session_args: Optional[Dict[str, Any]] = None,
) -> None:
"""Initialize the transport with the given aiohttp parameters.
Expand All @@ -53,6 +56,8 @@ def __init__(
:param cookies: Dict of HTTP cookies.
:param auth: BasicAuth object to enable Basic HTTP auth if needed
:param ssl: ssl_context of the connection. Use ssl=False to disable encryption
:param ssl_close_timeout: Timeout in seconds to wait for the ssl connection
to close properly
:param client_session_args: Dict of extra args passed to
`aiohttp.ClientSession`_

Expand All @@ -65,6 +70,7 @@ def __init__(
self.auth: Optional[BasicAuth] = auth
self.ssl: Union[SSLContext, bool, Fingerprint] = ssl
self.timeout: Optional[int] = timeout
self.ssl_close_timeout: Optional[Union[int, float]] = ssl_close_timeout
self.client_session_args = client_session_args
self.session: Optional[aiohttp.ClientSession] = None

Expand Down Expand Up @@ -100,6 +106,59 @@ async def connect(self) -> None:
else:
raise TransportAlreadyConnected("Transport is already connected")

@staticmethod
def create_aiohttp_closed_event(session) -> asyncio.Event:
"""Work around aiohttp issue that doesn't properly close transports on exit.

See https://github.com/aio-libs/aiohttp/issues/1925#issuecomment-639080209

Returns:
An event that will be set once all transports have been properly closed.
"""

ssl_transports = 0
all_is_lost = asyncio.Event()

def connection_lost(exc, orig_lost):
nonlocal ssl_transports

try:
orig_lost(exc)
finally:
ssl_transports -= 1
if ssl_transports == 0:
all_is_lost.set()

def eof_received(orig_eof_received):
try:
orig_eof_received()
except AttributeError: # pragma: no cover
# It may happen that eof_received() is called after
# _app_protocol and _transport are set to None.
pass

for conn in session.connector._conns.values():
for handler, _ in conn:
proto = getattr(handler.transport, "_ssl_protocol", None)
if proto is None:
continue

ssl_transports += 1
orig_lost = proto.connection_lost
orig_eof_received = proto.eof_received

proto.connection_lost = functools.partial(
connection_lost, orig_lost=orig_lost
)
proto.eof_received = functools.partial(
eof_received, orig_eof_received=orig_eof_received
)

if ssl_transports == 0:
all_is_lost.set()

return all_is_lost

async def close(self) -> None:
"""Coroutine which will close the aiohttp session.

Expand All @@ -108,7 +167,12 @@ async def close(self) -> None:
when you exit the async context manager.
"""
if self.session is not None:
closed_event = self.create_aiohttp_closed_event(self.session)
await self.session.close()
try:
await asyncio.wait_for(closed_event.wait(), self.ssl_close_timeout)
except asyncio.TimeoutError:
pass
self.session = None

async def execute(
Expand Down
54 changes: 37 additions & 17 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,7 @@ def pytest_collection_modifyitems(config, items):
item.add_marker(skip_transport)


@pytest.fixture
async def aiohttp_server():
async def aiohttp_server_base(with_ssl=False):
"""Factory to create a TestServer instance, given an app.

aiohttp_server(app, **kwargs)
Expand All @@ -89,7 +88,13 @@ async def aiohttp_server():

async def go(app, *, port=None, **kwargs): # type: ignore
server = AIOHTTPTestServer(app, port=port)
await server.start_server(**kwargs)

start_server_args = {**kwargs}
if with_ssl:
testcert, ssl_context = get_localhost_ssl_context()
start_server_args["ssl"] = ssl_context

await server.start_server(**start_server_args)
servers.append(server)
return server

Expand All @@ -99,6 +104,18 @@ async def go(app, *, port=None, **kwargs): # type: ignore
await servers.pop().close()


@pytest.fixture
async def aiohttp_server():
async for server in aiohttp_server_base():
yield server


@pytest.fixture
async def ssl_aiohttp_server():
async for server in aiohttp_server_base(with_ssl=True):
yield server


# Adding debug logs to websocket tests
for name in [
"websockets.legacy.server",
Expand All @@ -121,6 +138,22 @@ async def go(app, *, port=None, **kwargs): # type: ignore
MS = 0.001 * int(os.environ.get("GQL_TESTS_TIMEOUT_FACTOR", 1))


def get_localhost_ssl_context():
# This is a copy of certificate from websockets tests folder
#
# Generate TLS certificate with:
# $ openssl req -x509 -config test_localhost.cnf \
# -days 15340 -newkey rsa:2048 \
# -out test_localhost.crt -keyout test_localhost.key
# $ cat test_localhost.key test_localhost.crt > test_localhost.pem
# $ rm test_localhost.key test_localhost.crt
testcert = bytes(pathlib.Path(__file__).with_name("test_localhost.pem"))
ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
ssl_context.load_cert_chain(testcert)

return (testcert, ssl_context)


class WebSocketServer:
"""Websocket server on localhost on a free port.

Expand All @@ -141,20 +174,7 @@ async def start(self, handler, extra_serve_args=None):
extra_serve_args = {}

if self.with_ssl:
# This is a copy of certificate from websockets tests folder
#
# Generate TLS certificate with:
# $ openssl req -x509 -config test_localhost.cnf \
# -days 15340 -newkey rsa:2048 \
# -out test_localhost.crt -keyout test_localhost.key
# $ cat test_localhost.key test_localhost.crt > test_localhost.pem
# $ rm test_localhost.key test_localhost.crt
self.testcert = bytes(
pathlib.Path(__file__).with_name("test_localhost.pem")
)
ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
ssl_context.load_cert_chain(self.testcert)

self.testcert, ssl_context = get_localhost_ssl_context()
extra_serve_args["ssl"] = ssl_context

# Start a server with a random open port
Expand Down
35 changes: 35 additions & 0 deletions tests/test_aiohttp.py
Original file line number Diff line number Diff line change
Expand Up @@ -1073,3 +1073,38 @@ async def handler(request):
execution_result = await session.execute(query, get_execution_result=True)

assert execution_result.extensions["key1"] == "val1"


@pytest.mark.asyncio
@pytest.mark.parametrize("ssl_close_timeout", [0, 10])
async def test_aiohttp_query_https(event_loop, ssl_aiohttp_server, ssl_close_timeout):
from aiohttp import web
from gql.transport.aiohttp import AIOHTTPTransport

async def handler(request):
return web.Response(text=query1_server_answer, content_type="application/json")

app = web.Application()
app.router.add_route("POST", "/", handler)
server = await ssl_aiohttp_server(app)

url = server.make_url("/")

assert str(url).startswith("https://")

sample_transport = AIOHTTPTransport(
url=url, timeout=10, ssl_close_timeout=ssl_close_timeout
)

async with Client(transport=sample_transport,) as session:

query = gql(query1_str)

# Execute query asynchronously
result = await session.execute(query)

continents = result["continents"]

africa = continents[0]

assert africa["code"] == "AF"