Skip to content

Commit

Permalink
Drop the use of ExceptionGroup as mypy isn't ready
Browse files Browse the repository at this point in the history
This partially reverts 35f1a54
  • Loading branch information
fantix committed Sep 16, 2022
1 parent b9f4eeb commit c8b36bb
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 50 deletions.
34 changes: 18 additions & 16 deletions edb/common/signalctl.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ async def wait_for(self, fut, *, cancel_on=None):
if cancel_on is None:
cancel_on = self._signals

cancelled_by = []
cancelled_by = None
outer_cancelled_at_last = False

# The design here: we'll wait on a separate Future "waiter" for clean
Expand All @@ -122,7 +122,7 @@ async def wait_for(self, fut, *, cancel_on=None):
# 2. A signal is captured
# 3. The "waiter" is cancelled by outer code.
# For 2, we'll cancel the given "fut" and record the signal in
# cancelled_by to raise as an ExceptionGroup in the next step; for 3,
# cancelled_by as a __context__ chain to raise in the next step; for 3,
# we cancel the given "fut" and propagate the CancelledError later.
#
# The complexity of this design is: because our cancellation might be
Expand All @@ -132,7 +132,7 @@ async def wait_for(self, fut, *, cancel_on=None):
# are exhaustively executed until the "fut" is done, meanwhile the
# signals may keep hitting the "fut" code blocks, and "wait_for" is
# ready to handle them properly, and return all the SignalError objects
# in an ExceptionGroup preserving the order as they happen.
# in a __context__ chain preserving the order as they happen.
while not fut.done():
waiter = self._loop.create_future()
cb = functools.partial(_release_waiter, waiter)
Expand All @@ -154,7 +154,9 @@ async def wait_for(self, fut, *, cancel_on=None):
if not fut.done():
assert signal is not None
fut.cancel()
cancelled_by.append(SignalError(signal))
err = SignalError(signal)
err.__context__ = cancelled_by
cancelled_by = err
outer_cancelled_at_last = False

# Event 1: "fut" is done - exit the loop naturally.
Expand All @@ -180,34 +182,34 @@ async def wait_for(self, fut, *, cancel_on=None):
except asyncio.CancelledError as ex:
# 2. "fut" is cancelled - this usually means we caught a signal,
# but it could also be other reasons, see below.
if cancelled_by:
# Event 2 happened at least once, prepare an ExceptionGroup.
eg = ExceptionGroup("signal", cancelled_by)
if cancelled_by is not None:
# Event 2 happened at least once
if outer_cancelled_at_last:
# If event 3 is the last event, the outer code is probably
# expecting a CancelledError, e.g. asyncio.wait_for().
# Therefore, we just raise it with signal errors attached.
raise ex from eg
ex.__context__ = cancelled_by
rv = ex
else:
# If event 2 is the last event, simply raise the grouped
# signal errors, attaching the CancelledError to reveal
# where the signals hit the user code.
raise eg from ex
cancelled_by.__cause__ = ex
rv = cancelled_by
else:
# Neither event 2 nor 3 happened, the user code cancelled
# itself, simply propagate the same error.
raise

except Exception as e:
# 3. For any other errors, we'll merge it with recorded signal
# errors and raise as an ExceptionGroup, if event 2 happened.
if cancelled_by:
eg = ExceptionGroup("signal", [e, *cancelled_by])
# Not raising here - `eg` already contains `e`, we don't need
# to raise eg from e to include e again.
# 3. For any other errors, we just raise it with the signal errors
# attached as __context__ if event 2 happened.
if cancelled_by is not None:
e.__context__ = cancelled_by
rv = e
else:
raise
raise eg
raise rv

async def wait_for_signals(self):
waiter = QueueWaiter()
Expand Down
5 changes: 2 additions & 3 deletions edb/server/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,9 +259,8 @@ async def _run_server(

try:
await sc.wait_for(ss.serve_forever())
except* signalctl.SignalError as eg:
for e in eg.exceptions:
logger.info('Received signal: %s.', e.signo)
except signalctl.SignalError as e:
logger.info('Received signal: %s.', e.signo)
finally:
service_manager.sd_notify('STOPPING=1')
logger.info('Shutting down.')
Expand Down
12 changes: 0 additions & 12 deletions edb/testbase/proc.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
#

import asyncio
import contextlib
import sys
import unittest

Expand All @@ -26,8 +25,6 @@

exec(sys.argv[1], globals(), locals())

from edb.common import signalctl


class ProcTest(server.TestCase):
def notify_parent(self, mark):
Expand All @@ -39,15 +36,6 @@ async def wait_for_parent(self, mark):
str(mark).encode(),
)

@contextlib.contextmanager
def assertRaisesSignals(self, *signals):
try:
yield
except* signalctl.SignalError as eg:
self.assertEqual([e.signo for e in eg.exceptions], list(signals))
else:
self.fail("signalctl.SignalError not raised")

@classmethod
def setUpClass(cls):
super().setUpClass()
Expand Down
40 changes: 21 additions & 19 deletions tests/common/test_signalctl.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ async def spawn(test_prog, global_prog=""):
+ textwrap.dedent(
"""\
import signal
from edb.common import signalctl
"""
),
textwrap.dedent(test_prog),
Expand Down Expand Up @@ -93,7 +94,7 @@ async def task():
await asyncio.sleep(1)
with signalctl.SignalController(signal.SIGTERM) as sc:
with self.assertRaisesSignals(signal.SIGTERM):
with self.assertRaisesRegex(signalctl.SignalError, "SIGTERM"):
await sc.wait_for(task(), cancel_on={signal.SIGTERM})
self.notify_parent(2)
Expand All @@ -117,7 +118,7 @@ async def task():
await asyncio.sleep(1)
with signalctl.SignalController(signal.SIGINT) as sc:
with self.assertRaisesSignals(signal.SIGINT):
with self.assertRaisesRegex(signalctl.SignalError, "SIGINT"):
await sc.wait_for(task(), cancel_on={signal.SIGINT})
self.notify_parent(2)
Expand Down Expand Up @@ -147,7 +148,7 @@ async def task():
with signalctl.SignalController(
signal.SIGTERM, signal.SIGINT
) as sc:
with self.assertRaisesSignals(signal.SIGTERM):
with self.assertRaisesRegex(signalctl.SignalError, "SIGTERM"):
await sc.wait_for(task(), cancel_on={signal.SIGTERM})
self.notify_parent(2)
Expand Down Expand Up @@ -181,7 +182,7 @@ async def task():
with signalctl.SignalController(
signal.SIGTERM, signal.SIGINT
) as sc:
with self.assertRaisesSignals(signal.SIGINT):
with self.assertRaisesRegex(signalctl.SignalError, "SIGINT"):
await sc.wait_for(task(), cancel_on={signal.SIGINT})
self.notify_parent(2)
Expand Down Expand Up @@ -212,7 +213,7 @@ async def task():
sc.wait_for(self.wait_for_parent(1)), 0.1
)
with self.assertRaisesSignals(signal.SIGTERM):
with self.assertRaisesRegex(signalctl.SignalError, "SIGTERM"):
await sc.wait_for(task(), cancel_on={signal.SIGTERM})
self.notify_parent(4)
Expand Down Expand Up @@ -358,7 +359,7 @@ async def _task():
tg.create_task(_subtask2())
with signalctl.SignalController(signal.SIGTERM) as sc:
with self.assertRaisesSignals(signal.SIGTERM):
with self.assertRaisesRegex(signalctl.SignalError, "SIGTERM"):
await sc.wait_for(_task())
"""

Expand All @@ -382,7 +383,7 @@ async def task():
self.notify_parent(3)
with signalctl.SignalController(signal.SIGTERM) as sc:
with self.assertRaisesSignals(signal.SIGTERM):
with self.assertRaisesRegex(signalctl.SignalError, "SIGTERM"):
await sc.wait_for(task())
self.notify_parent(4)
"""
Expand Down Expand Up @@ -411,10 +412,12 @@ async def task():
with signalctl.SignalController(
signal.SIGTERM, signal.SIGINT, signal.SIGUSR1
) as sc:
with self.assertRaisesSignals(
signal.SIGTERM, signal.SIGINT, signal.SIGUSR1
):
with self.assertRaises(signalctl.SignalError) as ctx:
await sc.wait_for(task())
ex = ctx.exception
self.assertEqual(ex.signo, signal.SIGUSR1)
self.assertEqual(ex.__context__.signo, signal.SIGINT)
self.assertEqual(ex.__context__.__context__.signo, signal.SIGTERM)
self.notify_parent(7)
"""

Expand Down Expand Up @@ -451,12 +454,10 @@ async def _task():
await fut
await asyncio.wait_for(task, 0.1)
ex = ctx.exception
while not isinstance(ex, ExceptionGroup):
ex = getattr(ex, '__cause__', ex.__context__)
self.assertEqual(
[e.signo for e in ex.exceptions],
[signal.SIGTERM, signal.SIGINT],
)
while not isinstance(ex, signalctl.SignalError):
ex = ex.__context__
self.assertEqual(ex.signo, signal.SIGINT)
self.assertEqual(ex.__context__.signo, signal.SIGTERM)
self.notify_parent(5)
"""

Expand Down Expand Up @@ -486,12 +487,13 @@ async def _task():
with signalctl.SignalController(
signal.SIGTERM, signal.SIGINT
) as sc:
with self.assertRaisesSignals(
signal.SIGTERM, signal.SIGINT
) as ctx:
with self.assertRaises(signalctl.SignalError) as ctx:
task = self.loop.create_task(sc.wait_for(_task()))
await fut
await asyncio.wait_for(task, 0.1)
ex = ctx.exception
self.assertEqual(ex.signo, signal.SIGINT)
self.assertEqual(ex.__context__.signo, signal.SIGTERM)
self.notify_parent(5)
"""

Expand Down

0 comments on commit c8b36bb

Please sign in to comment.