-
-
Notifications
You must be signed in to change notification settings - Fork 6.2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. Weβll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Frontend] Don't log duplicate error stacktrace for every request in the batch #9023
Merged
Merged
Changes from all commits
Commits
Show all changes
4 commits
Select commit
Hold shift + click to select a range
e1c161e
[Frontend] Don't log duplicate error stacktrace for every request in β¦
wallashss f0f2920
[Frontend] improved test
wallashss a0c0532
[Frontend] assert MQEngineDeadError on test_batch_error
wallashss 68fbb83
Merge remote-tracking branch 'wallashss/main' into dont_duplicate_err
wallashss File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -59,15 +59,7 @@ async def test_evil_forward(tmp_socket): | |
await asyncio.sleep(2.0) | ||
await client.check_health() | ||
|
||
# Throws an error in first forward pass. | ||
with pytest.raises(RAISED_ERROR): | ||
async for _ in client.generate(prompt="Hello my name is", | ||
sampling_params=SamplingParams(), | ||
request_id=uuid.uuid4()): | ||
pass | ||
assert client.errored | ||
|
||
# Engine is errored, should get ENGINE_DEAD_ERROR. | ||
# Throws an error that should get ENGINE_DEAD_ERROR. | ||
with pytest.raises(MQEngineDeadError): | ||
async for _ in client.generate(prompt="Hello my name is", | ||
sampling_params=SamplingParams(), | ||
|
@@ -149,7 +141,7 @@ async def test_failed_abort(tmp_socket): | |
client = await engine.make_client() | ||
assert client.is_running | ||
|
||
# Firsh check health should work. | ||
# First check health should work. | ||
await client.check_health() | ||
|
||
# Trigger an abort on the client side. | ||
|
@@ -174,6 +166,45 @@ async def test_failed_abort(tmp_socket): | |
client.close() | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_batch_error(tmp_socket): | ||
with RemoteMQLLMEngine(engine_args=ENGINE_ARGS, | ||
ipc_path=tmp_socket, | ||
run_fn=run_with_evil_abort) as engine: | ||
|
||
client = await engine.make_client() | ||
assert client.is_running | ||
|
||
# First check health should work. | ||
await client.check_health() | ||
|
||
# Batch of requests | ||
async def do_generate(client): | ||
# min_tokens=2048 to keep busy the engine busy | ||
# to get enough time to get process a request | ||
# that will crash the engine | ||
params = SamplingParams(min_tokens=2048, max_tokens=2048) | ||
async for _ in client.generate(prompt="Hello my name is", | ||
sampling_params=params, | ||
request_id=uuid.uuid4()): | ||
pass | ||
|
||
tasks = [asyncio.create_task(do_generate(client)) for _ in range(10)] | ||
|
||
# This request will force a processing batch to raise | ||
# an exception and next the engine get errored | ||
await client.abort(request_id="foo") | ||
|
||
# The batch of those request failed, then they | ||
# should get the same exception as a MQEngineDeadError. | ||
errors = await asyncio.gather(*tasks, return_exceptions=True) | ||
for e in errors: | ||
assert isinstance(e, MQEngineDeadError) | ||
assert "KeyError" in repr(e) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @wallashss I think we need to assert that these errors are also |
||
|
||
client.close() | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_bad_request(tmp_socket): | ||
with RemoteMQLLMEngine(engine_args=ENGINE_ARGS, | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -204,8 +204,20 @@ async def run_output_handler_loop(self): | |
# (and record only the first one) | ||
if is_engine_errored and not self._errored_with: | ||
self._errored_with = exception | ||
# If engine is errored, no matter the type of exception | ||
# it will no longer be able to receive new requests, | ||
# therefore we have to inform that the current | ||
# processed requests failed as well. Send back a dead | ||
# engine error give this feedback and also give a | ||
# 'hint' to the server to shutdown next. | ||
exception = self.dead_error | ||
|
||
if request_id is None: | ||
# If request_id is None, then the engine raised an | ||
# exception for a batch, and we may not know the | ||
# request that caused it, neither if it was actually | ||
# caused by any of them (e.g. CUDA OOM). Therefore we | ||
# broadcast the same exception for all requests. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nice, love the explanation here! |
||
for queue_i in tuple(self.output_queues.values()): | ||
queue_i.put_nowait(exception) | ||
else: | ||
|
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We could also check a "batch" of requests here, like
That should test that we don't get the big spew of stack traces, since every request will raise an error type that doesn't log the stack trace