Skip to content

Commit

Permalink
Add correct kwargs to ClientSession init to allow graceful AsyncHTTPP…
Browse files Browse the repository at this point in the history
…rovider shutdown (#3557)
  • Loading branch information
kclowes authored Dec 19, 2024
1 parent 0893752 commit 3b9c0f0
Show file tree
Hide file tree
Showing 5 changed files with 63 additions and 3 deletions.
2 changes: 2 additions & 0 deletions docs/providers.rst
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,8 @@ AsyncHTTPProvider
>>> # If you want to pass in your own session:
>>> custom_session = ClientSession()
>>> await w3.provider.cache_async_session(custom_session) # This method is an async method so it needs to be handled accordingly
>>> # when you're finished, disconnect:
>>> w3.provider.disconnect()
Under the hood, the ``AsyncHTTPProvider`` uses the python
`aiohttp <https://docs.aiohttp.org/en/stable/>`_ library for making requests.
Expand Down
1 change: 1 addition & 0 deletions newsfragments/3557.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add a ``disconnect`` method to the AsyncHTTPProvider that closes all sessions and clears the cache
40 changes: 39 additions & 1 deletion tests/integration/go_ethereum/test_goethereum_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,45 @@ class TestGoEthereumAsyncNetModuleTest(GoEthereumAsyncNetModuleTest):


class TestGoEthereumAsyncEthModuleTest(GoEthereumAsyncEthModuleTest):
pass
@pytest.mark.asyncio
async def test_async_http_provider_disconnects_gracefully(
self, async_w3, endpoint_uri
) -> None:
w3_1 = async_w3

w3_2 = AsyncWeb3(AsyncHTTPProvider(endpoint_uri))
assert w3_1 != w3_2

await w3_1.eth.get_block("latest")
await w3_2.eth.get_block("latest")

w3_1_session_cache = w3_1.provider._request_session_manager.session_cache
w3_2_session_cache = w3_2.provider._request_session_manager.session_cache

for _, session in w3_1_session_cache.items():
assert not session.closed
for _, session in w3_2_session_cache.items():
assert not session.closed
assert w3_1_session_cache != w3_2_session_cache

await w3_1.provider.disconnect()
await w3_2.provider.disconnect()

assert len(w3_1_session_cache) == 0
assert len(w3_2_session_cache) == 0

@pytest.mark.asyncio
async def test_async_http_provider_reuses_cached_session(self, async_w3) -> None:
await async_w3.eth.get_block("latest")
session_cache = async_w3.provider._request_session_manager.session_cache
assert len(session_cache) == 1
session = list(session_cache._data.values())[0]

await async_w3.eth.get_block("latest")
assert len(session_cache) == 1
assert session == list(session_cache._data.values())[0]
await async_w3.provider.disconnect()
assert len(session_cache) == 0


class TestGoEthereumAsyncTxPoolModuleTest(GoEthereumAsyncTxPoolModuleTest):
Expand Down
15 changes: 13 additions & 2 deletions web3/_utils/http_session_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
ClientResponse,
ClientSession,
ClientTimeout,
TCPConnector,
)
from eth_typing import (
URI,
Expand Down Expand Up @@ -174,7 +175,12 @@ async def async_cache_and_return_session(
async with async_lock(self.session_pool, self._lock):
if cache_key not in self.session_cache:
if session is None:
session = ClientSession(raise_for_status=True)
session = ClientSession(
raise_for_status=True,
connector=TCPConnector(
force_close=True, enable_cleanup_closed=True
),
)

cached_session, evicted_items = self.session_cache.cache(
cache_key, session
Expand Down Expand Up @@ -213,7 +219,12 @@ async def async_cache_and_return_session(
)

# replace stale session with a new session at the cache key
_session = ClientSession(raise_for_status=True)
_session = ClientSession(
raise_for_status=True,
connector=TCPConnector(
force_close=True, enable_cleanup_closed=True
),
)
cached_session, evicted_items = self.session_cache.cache(
cache_key, _session
)
Expand Down
8 changes: 8 additions & 0 deletions web3/providers/rpc/async_rpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,3 +177,11 @@ async def make_batch_request(
self.logger.debug("Received batch response HTTP.")
responses_list = cast(List[RPCResponse], self.decode_rpc_response(raw_response))
return sort_batch_response_by_response_ids(responses_list)

async def disconnect(self) -> None:
cache = self._request_session_manager.session_cache
for _, session in cache.items():
await session.close()
cache.clear()

self.logger.info(f"Successfully disconnected from: {self.endpoint_uri}")

0 comments on commit 3b9c0f0

Please sign in to comment.