Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[tests] handler decomposition #125

Merged
merged 4 commits into from
Nov 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ dev-dependencies = [
"pytest-mock>=3.11.1",
"ruff>=0.0.278",
"types-protobuf>=4.24.0.20240311",
"types-nanoid>=2.0.0.20240601",
]

[tool.ruff]
Expand Down
82 changes: 82 additions & 0 deletions tests/common_handlers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
from typing import Any, AsyncGenerator, AsyncIterator, Iterator

import grpc
import grpc.aio

from replit_river.rpc import (
rpc_method_handler,
stream_method_handler,
subscription_method_handler,
upload_method_handler,
)
from tests.conftest import HandlerMapping, deserialize_request, serialize_response


async def rpc_handler(request: str, context: grpc.aio.ServicerContext) -> str:
return f"Hello, {request}!"


basic_rpc_method: HandlerMapping = {
("test_service", "rpc_method"): (
"rpc",
rpc_method_handler(rpc_handler, deserialize_request, serialize_response),
)
}


async def upload_handler(
request: Iterator[str] | AsyncIterator[str], context: Any
) -> str:
uploaded_data = []
if isinstance(request, AsyncIterator):
async for data in request:
uploaded_data.append(data)
else:
for data in request:
uploaded_data.append(data)
return f"Uploaded: {', '.join(uploaded_data)}"


basic_upload: HandlerMapping = {
("test_service", "upload_method"): (
"upload",
upload_method_handler(upload_handler, deserialize_request, serialize_response),
),
}


async def subscription_handler(
request: str, context: grpc.aio.ServicerContext
) -> AsyncGenerator[str, None]:
for i in range(5):
yield f"Subscription message {i} for {request}"


basic_subscription: HandlerMapping = {
("test_service", "subscription_method"): (
"subscription",
subscription_method_handler(
subscription_handler, deserialize_request, serialize_response
),
),
}


async def stream_handler(
request: Iterator[str] | AsyncIterator[str],
context: grpc.aio.ServicerContext,
) -> AsyncGenerator[str, None]:
if isinstance(request, AsyncIterator):
async for data in request:
yield f"Stream response for {data}"
else:
for data in request:
yield f"Stream response for {data}"


basic_stream: HandlerMapping = {
("test_service", "stream_method"): (
"stream",
stream_method_handler(stream_handler, deserialize_request, serialize_response),
),
}
104 changes: 15 additions & 89 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
import asyncio
import logging
from collections.abc import AsyncIterator
from typing import Any, AsyncGenerator, Iterator, Literal
from typing import Any, AsyncGenerator, Literal, Mapping

import grpc.aio
import nanoid # type: ignore
import nanoid
import pytest
from opentelemetry import trace
from opentelemetry.sdk.trace import TracerProvider
Expand All @@ -14,13 +12,10 @@

from replit_river.client import Client
from replit_river.client_transport import UriAndMetadata
from replit_river.error_schema import RiverError, RiverException
from replit_river.error_schema import RiverError
from replit_river.rpc import (
GenericRpcHandler,
TransportMessage,
rpc_method_handler,
stream_method_handler,
subscription_method_handler,
upload_method_handler,
)
from replit_river.server import Server
from replit_river.transport_options import TransportOptions
Expand All @@ -29,6 +24,8 @@
# Modular fixtures
pytest_plugins = ["tests.river_fixtures.logging"]

HandlerMapping = Mapping[tuple[str, str], tuple[str, GenericRpcHandler]]


def transport_message(
seq: int = 0,
Expand Down Expand Up @@ -71,93 +68,22 @@ def deserialize_error(response: dict) -> RiverError:
return RiverError.model_validate(response)


# RPC method handlers for testing
async def rpc_handler(request: str, context: grpc.aio.ServicerContext) -> str:
return f"Hello, {request}!"


async def subscription_handler(
request: str, context: grpc.aio.ServicerContext
) -> AsyncGenerator[str, None]:
for i in range(5):
yield f"Subscription message {i} for {request}"


async def upload_handler(
request: Iterator[str] | AsyncIterator[str], context: Any
) -> str:
uploaded_data = []
if isinstance(request, AsyncIterator):
async for data in request:
uploaded_data.append(data)
else:
for data in request:
uploaded_data.append(data)
return f"Uploaded: {', '.join(uploaded_data)}"


async def stream_handler(
request: Iterator[str] | AsyncIterator[str],
context: grpc.aio.ServicerContext,
) -> AsyncGenerator[str, None]:
if isinstance(request, AsyncIterator):
async for data in request:
yield f"Stream response for {data}"
else:
for data in request:
yield f"Stream response for {data}"


async def stream_error_handler(
request: Iterator[str] | AsyncIterator[str],
context: grpc.aio.ServicerContext,
) -> AsyncGenerator[str, None]:
raise RiverException("INJECTED_ERROR", "test error")
yield "test" # appease the type checker


@pytest.fixture
def transport_options() -> TransportOptions:
return TransportOptions()


@pytest.fixture
def server(transport_options: TransportOptions) -> Server:
def server_handlers(handlers: HandlerMapping) -> HandlerMapping:
return handlers


@pytest.fixture
def server(
transport_options: TransportOptions, server_handlers: HandlerMapping
) -> Server:
server = Server(server_id="test_server", transport_options=transport_options)
server.add_rpc_handlers(
{
("test_service", "rpc_method"): (
"rpc",
rpc_method_handler(
rpc_handler, deserialize_request, serialize_response
),
),
("test_service", "subscription_method"): (
"subscription",
subscription_method_handler(
subscription_handler, deserialize_request, serialize_response
),
),
("test_service", "upload_method"): (
"upload",
upload_method_handler(
upload_handler, deserialize_request, serialize_response
),
),
("test_service", "stream_method"): (
"stream",
stream_method_handler(
stream_handler, deserialize_request, serialize_response
),
),
("test_service", "stream_method_error"): (
"stream",
stream_method_handler(
stream_error_handler, deserialize_request, serialize_response
),
),
}
)
server.add_rpc_handlers(server_handlers)
return server


Expand Down
20 changes: 19 additions & 1 deletion tests/test_communication.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,21 @@
from replit_river.client import Client
from replit_river.error_schema import RiverError
from replit_river.transport_options import MAX_MESSAGE_BUFFER_SIZE
from tests.conftest import deserialize_error, deserialize_response, serialize_request
from tests.common_handlers import (
basic_rpc_method,
basic_stream,
basic_subscription,
basic_upload,
)
from tests.conftest import (
deserialize_error,
deserialize_response,
serialize_request,
)


@pytest.mark.asyncio
@pytest.mark.parametrize("handlers", [{**basic_rpc_method}])
async def test_rpc_method(client: Client) -> None:
response = await client.send_rpc(
"test_service",
Expand All @@ -23,6 +34,7 @@ async def test_rpc_method(client: Client) -> None:


@pytest.mark.asyncio
@pytest.mark.parametrize("handlers", [{**basic_upload}])
async def test_upload_method(client: Client) -> None:
async def upload_data() -> AsyncGenerator[str, None]:
yield "Data 1"
Expand All @@ -43,6 +55,7 @@ async def upload_data() -> AsyncGenerator[str, None]:


@pytest.mark.asyncio
@pytest.mark.parametrize("handlers", [{**basic_upload}])
async def test_upload_more_than_send_buffer_max(client: Client) -> None:
iterations = MAX_MESSAGE_BUFFER_SIZE * 2

Expand All @@ -64,6 +77,7 @@ async def upload_data() -> AsyncGenerator[str, None]:


@pytest.mark.asyncio
@pytest.mark.parametrize("handlers", [{**basic_upload}])
async def test_upload_empty(client: Client) -> None:
async def upload_data(enabled: bool = False) -> AsyncGenerator[str, None]:
if enabled:
Expand All @@ -83,6 +97,7 @@ async def upload_data(enabled: bool = False) -> AsyncGenerator[str, None]:


@pytest.mark.asyncio
@pytest.mark.parametrize("handlers", [{**basic_subscription}])
async def test_subscription_method(client: Client) -> None:
async for response in client.send_subscription(
"test_service",
Expand All @@ -97,6 +112,7 @@ async def test_subscription_method(client: Client) -> None:


@pytest.mark.asyncio
@pytest.mark.parametrize("handlers", [{**basic_stream}])
async def test_stream_method(client: Client) -> None:
async def stream_data() -> AsyncGenerator[str, None]:
yield "Stream 1"
Expand Down Expand Up @@ -125,6 +141,7 @@ async def stream_data() -> AsyncGenerator[str, None]:


@pytest.mark.asyncio
@pytest.mark.parametrize("handlers", [{**basic_stream}])
async def test_stream_empty(client: Client) -> None:
async def stream_data(enabled: bool = False) -> AsyncGenerator[str, None]:
if enabled:
Expand All @@ -147,6 +164,7 @@ async def stream_data(enabled: bool = False) -> AsyncGenerator[str, None]:


@pytest.mark.asyncio
@pytest.mark.parametrize("handlers", [{**basic_upload, **basic_stream}])
async def test_multiplexing(client: Client) -> None:
async def upload_data() -> AsyncGenerator[str, None]:
yield "Upload Data 1"
Expand Down
1 change: 1 addition & 0 deletions tests/test_handshake.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ def transport_options() -> TransportOptions:


@pytest.mark.asyncio
@pytest.mark.parametrize("handlers", [{}])
async def test_handshake_timeout(server: Server) -> None:
async with serve(server.serve, "localhost", 8765):
start = time()
Expand Down
Loading
Loading