From 3b9c0f09553b2f30138285466a3b8fb8c62db600 Mon Sep 17 00:00:00 2001 From: kclowes Date: Thu, 19 Dec 2024 11:35:16 -0700 Subject: [PATCH] Add correct kwargs to ClientSession init to allow graceful AsyncHTTPProvider shutdown (#3557) --- docs/providers.rst | 2 + newsfragments/3557.bugfix.rst | 1 + .../go_ethereum/test_goethereum_http.py | 40 ++++++++++++++++++- web3/_utils/http_session_manager.py | 15 ++++++- web3/providers/rpc/async_rpc.py | 8 ++++ 5 files changed, 63 insertions(+), 3 deletions(-) create mode 100644 newsfragments/3557.bugfix.rst diff --git a/docs/providers.rst b/docs/providers.rst index 0e9a45854e..8489948d79 100644 --- a/docs/providers.rst +++ b/docs/providers.rst @@ -214,6 +214,8 @@ AsyncHTTPProvider >>> # If you want to pass in your own session: >>> custom_session = ClientSession() >>> await w3.provider.cache_async_session(custom_session) # This method is an async method so it needs to be handled accordingly + >>> # when you're finished, disconnect: + >>> w3.provider.disconnect() Under the hood, the ``AsyncHTTPProvider`` uses the python `aiohttp `_ library for making requests. diff --git a/newsfragments/3557.bugfix.rst b/newsfragments/3557.bugfix.rst new file mode 100644 index 0000000000..0b6b22992a --- /dev/null +++ b/newsfragments/3557.bugfix.rst @@ -0,0 +1 @@ +Add a ``disconnect`` method to the AsyncHTTPProvider that closes all sessions and clears the cache diff --git a/tests/integration/go_ethereum/test_goethereum_http.py b/tests/integration/go_ethereum/test_goethereum_http.py index 51e18c9bda..42b7959bdb 100644 --- a/tests/integration/go_ethereum/test_goethereum_http.py +++ b/tests/integration/go_ethereum/test_goethereum_http.py @@ -159,7 +159,45 @@ class TestGoEthereumAsyncNetModuleTest(GoEthereumAsyncNetModuleTest): class TestGoEthereumAsyncEthModuleTest(GoEthereumAsyncEthModuleTest): - pass + @pytest.mark.asyncio + async def test_async_http_provider_disconnects_gracefully( + self, async_w3, endpoint_uri + ) -> None: + w3_1 = async_w3 + + w3_2 = AsyncWeb3(AsyncHTTPProvider(endpoint_uri)) + assert w3_1 != w3_2 + + await w3_1.eth.get_block("latest") + await w3_2.eth.get_block("latest") + + w3_1_session_cache = w3_1.provider._request_session_manager.session_cache + w3_2_session_cache = w3_2.provider._request_session_manager.session_cache + + for _, session in w3_1_session_cache.items(): + assert not session.closed + for _, session in w3_2_session_cache.items(): + assert not session.closed + assert w3_1_session_cache != w3_2_session_cache + + await w3_1.provider.disconnect() + await w3_2.provider.disconnect() + + assert len(w3_1_session_cache) == 0 + assert len(w3_2_session_cache) == 0 + + @pytest.mark.asyncio + async def test_async_http_provider_reuses_cached_session(self, async_w3) -> None: + await async_w3.eth.get_block("latest") + session_cache = async_w3.provider._request_session_manager.session_cache + assert len(session_cache) == 1 + session = list(session_cache._data.values())[0] + + await async_w3.eth.get_block("latest") + assert len(session_cache) == 1 + assert session == list(session_cache._data.values())[0] + await async_w3.provider.disconnect() + assert len(session_cache) == 0 class TestGoEthereumAsyncTxPoolModuleTest(GoEthereumAsyncTxPoolModuleTest): diff --git a/web3/_utils/http_session_manager.py b/web3/_utils/http_session_manager.py index f689d69591..353b40511c 100644 --- a/web3/_utils/http_session_manager.py +++ b/web3/_utils/http_session_manager.py @@ -18,6 +18,7 @@ ClientResponse, ClientSession, ClientTimeout, + TCPConnector, ) from eth_typing import ( URI, @@ -174,7 +175,12 @@ async def async_cache_and_return_session( async with async_lock(self.session_pool, self._lock): if cache_key not in self.session_cache: if session is None: - session = ClientSession(raise_for_status=True) + session = ClientSession( + raise_for_status=True, + connector=TCPConnector( + force_close=True, enable_cleanup_closed=True + ), + ) cached_session, evicted_items = self.session_cache.cache( cache_key, session @@ -213,7 +219,12 @@ async def async_cache_and_return_session( ) # replace stale session with a new session at the cache key - _session = ClientSession(raise_for_status=True) + _session = ClientSession( + raise_for_status=True, + connector=TCPConnector( + force_close=True, enable_cleanup_closed=True + ), + ) cached_session, evicted_items = self.session_cache.cache( cache_key, _session ) diff --git a/web3/providers/rpc/async_rpc.py b/web3/providers/rpc/async_rpc.py index 8f723138eb..bc8ef17d25 100644 --- a/web3/providers/rpc/async_rpc.py +++ b/web3/providers/rpc/async_rpc.py @@ -177,3 +177,11 @@ async def make_batch_request( self.logger.debug("Received batch response HTTP.") responses_list = cast(List[RPCResponse], self.decode_rpc_response(raw_response)) return sort_batch_response_by_response_ids(responses_list) + + async def disconnect(self) -> None: + cache = self._request_session_manager.session_cache + for _, session in cache.items(): + await session.close() + cache.clear() + + self.logger.info(f"Successfully disconnected from: {self.endpoint_uri}")