Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Don't record GeneratorExit errors in stream RPC spans #129

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 64 additions & 30 deletions replit_river/client.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import logging
from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Callable
from collections.abc import AsyncIterable, Awaitable, Callable
from contextlib import contextmanager
from dataclasses import dataclass
from datetime import timedelta
from typing import Any, Generator, Generic, Literal, Optional, Union
from typing import Any, AsyncGenerator, Generator, Generic, Literal, Optional, Union

from opentelemetry import trace
from opentelemetry.trace import Span, SpanKind, StatusCode
from opentelemetry.trace import Span, SpanKind, Status, StatusCode

from replit_river.client_transport import ClientTransport
from replit_river.error_schema import RiverError, RiverException
Expand Down Expand Up @@ -63,7 +64,7 @@ async def send_rpc(
error_deserializer: Callable[[Any], ErrorType],
timeout: timedelta,
) -> ResponseType:
with _trace_procedure("rpc", service_name, procedure_name) as span:
with _trace_procedure("rpc", service_name, procedure_name) as span_handle:
session = await self._transport.get_or_create_session()
return await session.send_rpc(
service_name,
Expand All @@ -72,7 +73,7 @@ async def send_rpc(
request_serializer,
response_deserializer,
error_deserializer,
span,
span_handle.span,
timeout,
)

Expand All @@ -87,7 +88,7 @@ async def send_upload(
response_deserializer: Callable[[Any], ResponseType],
error_deserializer: Callable[[Any], ErrorType],
) -> ResponseType:
with _trace_procedure("upload", service_name, procedure_name) as span:
with _trace_procedure("upload", service_name, procedure_name) as span_handle:
session = await self._transport.get_or_create_session()
return await session.send_upload(
service_name,
Expand All @@ -98,7 +99,7 @@ async def send_upload(
request_serializer,
response_deserializer,
error_deserializer,
span,
span_handle.span,
)

async def send_subscription(
Expand All @@ -109,8 +110,10 @@ async def send_subscription(
request_serializer: Callable[[RequestType], Any],
response_deserializer: Callable[[Any], ResponseType],
error_deserializer: Callable[[Any], ErrorType],
) -> AsyncIterator[Union[ResponseType, ErrorType]]:
with _trace_procedure("subscription", service_name, procedure_name) as span:
) -> AsyncGenerator[Union[ResponseType, ErrorType], None]:
with _trace_procedure(
"subscription", service_name, procedure_name
) as span_handle:
session = await self._transport.get_or_create_session()
async for msg in session.send_subscription(
service_name,
Expand All @@ -119,10 +122,10 @@ async def send_subscription(
request_serializer,
response_deserializer,
error_deserializer,
span,
span_handle.span,
):
if isinstance(msg, RiverError):
_record_river_error(span, msg)
_record_river_error(span_handle, msg)
yield msg # type: ignore # https://github.com/python/mypy/issues/10817

async def send_stream(
Expand All @@ -135,8 +138,8 @@ async def send_stream(
request_serializer: Callable[[RequestType], Any],
response_deserializer: Callable[[Any], ResponseType],
error_deserializer: Callable[[Any], ErrorType],
) -> AsyncIterator[Union[ResponseType, ErrorType]]:
with _trace_procedure("stream", service_name, procedure_name) as span:
) -> AsyncGenerator[Union[ResponseType, ErrorType], None]:
with _trace_procedure("stream", service_name, procedure_name) as span_handle:
session = await self._transport.get_or_create_session()
async for msg in session.send_stream(
service_name,
Expand All @@ -147,32 +150,63 @@ async def send_stream(
request_serializer,
response_deserializer,
error_deserializer,
span,
span_handle.span,
):
if isinstance(msg, RiverError):
_record_river_error(span, msg)
_record_river_error(span_handle, msg)
yield msg # type: ignore # https://github.com/python/mypy/issues/10817


@dataclass
class _SpanHandle:
"""Wraps a span and keeps track of whether or not a status has been recorded yet."""

span: Span
did_set_status: bool = False

def set_status(
self,
status: Union[Status, StatusCode],
description: Optional[str] = None,
) -> None:
if self.did_set_status:
return
self.did_set_status = True
self.span.set_status(status, description)


@contextmanager
def _trace_procedure(
procedure_type: Literal["rpc", "upload", "subscription", "stream"],
service_name: str,
procedure_name: str,
) -> Generator[Span, None, None]:
with tracer.start_span(
) -> Generator[_SpanHandle, None, None]:
span = tracer.start_span(
f"river.client.{procedure_type}.{service_name}.{procedure_name}",
kind=SpanKind.CLIENT,
) as span:
try:
yield span
except RiverException as e:
_record_river_error(span, RiverError(code=e.code, message=e.message))
raise e


def _record_river_error(span: Span, error: RiverError) -> None:
span.set_status(StatusCode.ERROR, error.message)
span.record_exception(RiverException(error.code, error.message))
span.set_attribute("river.error_code", error.code)
span.set_attribute("river.error_message", error.message)
)
span_handle = _SpanHandle(span)
try:
yield span_handle
except GeneratorExit:
# This error indicates the caller is done with the async generator
# but messages are still left. This is okay, we do not consider it an error.
raise
except RiverException as e:
span.record_exception(e, escaped=True)
_record_river_error(span_handle, RiverError(code=e.code, message=e.message))
raise e
except BaseException as e:
span.record_exception(e, escaped=True)
span_handle.set_status(StatusCode.ERROR, f"{type(e).__name__}: {e}")
raise e
finally:
span_handle.set_status(StatusCode.OK)
span.end()


def _record_river_error(span_handle: _SpanHandle, error: RiverError) -> None:
span_handle.set_status(StatusCode.ERROR, error.message)
span_handle.span.record_exception(RiverException(error.code, error.message))
span_handle.span.set_attribute("river.error_code", error.code)
span_handle.span.set_attribute("river.error_message", error.message)
8 changes: 4 additions & 4 deletions replit_river/client_session.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import asyncio
import logging
from collections.abc import AsyncIterable, AsyncIterator
from collections.abc import AsyncIterable
from datetime import timedelta
from typing import Any, Callable, Optional, Union
from typing import Any, AsyncGenerator, Callable, Optional, Union

import nanoid # type: ignore
from aiochannel import Channel
Expand Down Expand Up @@ -194,7 +194,7 @@ async def send_subscription(
response_deserializer: Callable[[Any], ResponseType],
error_deserializer: Callable[[Any], ErrorType],
span: Span,
) -> AsyncIterator[Union[ResponseType, ErrorType]]:
) -> AsyncGenerator[Union[ResponseType, ErrorType], None]:
"""Sends a subscription request to the server.

Expects the input and output be messages that will be msgpacked.
Expand Down Expand Up @@ -246,7 +246,7 @@ async def send_stream(
response_deserializer: Callable[[Any], ResponseType],
error_deserializer: Callable[[Any], ErrorType],
span: Span,
) -> AsyncIterator[Union[ResponseType, ErrorType]]:
) -> AsyncGenerator[Union[ResponseType, ErrorType], None]:
"""Sends a subscription request to the server.

Expects the input and output be messages that will be msgpacked.
Expand Down
37 changes: 37 additions & 0 deletions tests/test_opentelemetry.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import contextlib
from datetime import timedelta
from typing import AsyncGenerator, AsyncIterator, Iterator

Expand Down Expand Up @@ -182,3 +183,39 @@ async def stream_data() -> AsyncGenerator[str, None]:
assert len(spans) == 1
assert spans[0].name == "river.client.stream.test_service.stream_method_error"
assert spans[0].status.status_code == StatusCode.ERROR


@pytest.mark.asyncio
@pytest.mark.parametrize("handlers", [{**basic_stream}])
async def test_stream_method_span_generator_exit_not_recorded(
client: Client, span_exporter: InMemorySpanExporter
) -> None:
async def stream_data() -> AsyncGenerator[str, None]:
yield "Stream 1"
yield "Stream 2"
yield "Stream 3"

responses = []
stream = client.send_stream(
"test_service",
"stream_method",
"Initial Stream Data",
stream_data(),
serialize_request,
serialize_request,
deserialize_response,
deserialize_error,
)
async with contextlib.aclosing(stream) as generator:
async for response in generator:
responses.append(response)
break

assert responses == [
"Stream response for Initial Stream Data",
]

spans = span_exporter.get_finished_spans()
assert len(spans) == 1
assert spans[0].name == "river.client.stream.test_service.stream_method"
assert spans[0].status.status_code == StatusCode.OK
Loading