From 12c108cd0a34824529b83d8f2ebb21f0217a4289 Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Thu, 6 Oct 2022 10:39:33 -0500 Subject: [PATCH 1/8] recover from git blunder --- starlette/_utils.py | 54 ++++++++++++++++++++++++++++++++++++++++++ starlette/requests.py | 15 ++++++++---- tests/test_requests.py | 13 ++++++++++ 3 files changed, 77 insertions(+), 5 deletions(-) diff --git a/starlette/_utils.py b/starlette/_utils.py index 0710aebdc..d797c39fb 100644 --- a/starlette/_utils.py +++ b/starlette/_utils.py @@ -1,6 +1,13 @@ import asyncio import functools +import sys import typing +from types import TracebackType + +if sys.version_info < (3, 8): # pragma: no cover + from typing_extensions import Protocol +else: + from typing import Protocol def is_async_callable(obj: typing.Any) -> bool: @@ -10,3 +17,50 @@ def is_async_callable(obj: typing.Any) -> bool: return asyncio.iscoroutinefunction(obj) or ( callable(obj) and asyncio.iscoroutinefunction(obj.__call__) ) + + +T_co = typing.TypeVar("T_co", covariant=True) + + +class AwaitableOrContextManager(Protocol[T_co]): + def __await__(self) -> typing.Generator[typing.Any, None, T_co]: + ... # pragma: no cover + + async def __aenter__(self) -> T_co: + ... # pragma: no cover + + async def __aexit__( + self, + __exc_type: typing.Optional[typing.Type[BaseException]], + __exc_value: typing.Optional[BaseException], + __traceback: typing.Optional[TracebackType], + ) -> typing.Union[bool, None]: + ... # pragma: no cover + + +class SupportsAsyncClose(Protocol): + async def close(self) -> None: + ... # pragma: no cover + + +SupportsAsyncCloseType = typing.TypeVar( + "SupportsAsyncCloseType", bound=SupportsAsyncClose, covariant=False +) + + +class AwaitableOrContextManagerWrapper(typing.Generic[SupportsAsyncCloseType]): + __slots__ = ("aw", "entered") + + def __init__(self, aw: typing.Awaitable[SupportsAsyncCloseType]) -> None: + self.aw = aw + + def __await__(self) -> typing.Generator[typing.Any, None, SupportsAsyncCloseType]: + return self.aw.__await__() + + async def __aenter__(self) -> SupportsAsyncCloseType: + self.entered = await self.aw + return self.entered + + async def __aexit__(self, *args: typing.Any) -> typing.Union[None, bool]: + await self.entered.close() + return None diff --git a/starlette/requests.py b/starlette/requests.py index 726abddcc..d924b501a 100644 --- a/starlette/requests.py +++ b/starlette/requests.py @@ -4,6 +4,7 @@ import anyio +from starlette._utils import AwaitableOrContextManager, AwaitableOrContextManagerWrapper from starlette.datastructures import URL, Address, FormData, Headers, QueryParams, State from starlette.exceptions import HTTPException from starlette.formparsers import FormParser, MultiPartException, MultiPartParser @@ -187,6 +188,8 @@ async def empty_send(message: Message) -> typing.NoReturn: class Request(HTTPConnection): + _form: typing.Optional[FormData] + def __init__( self, scope: Scope, receive: Receive = empty_receive, send: Send = empty_send ): @@ -196,6 +199,7 @@ def __init__( self._send = send self._stream_consumed = False self._is_disconnected = False + self._form = None @property def method(self) -> str: @@ -210,10 +214,8 @@ async def stream(self) -> typing.AsyncGenerator[bytes, None]: yield self._body yield b"" return - if self._stream_consumed: raise RuntimeError("Stream consumed") - self._stream_consumed = True while True: message = await self._receive() @@ -242,8 +244,8 @@ async def json(self) -> typing.Any: self._json = json.loads(body) return self._json - async def form(self) -> FormData: - if not hasattr(self, "_form"): + async def _get_form(self) -> FormData: + if self._form is None: assert ( parse_options_header is not None ), "The `python-multipart` library must be installed to use form parsing." @@ -265,8 +267,11 @@ async def form(self) -> FormData: self._form = FormData() return self._form + def form(self) -> AwaitableOrContextManager[FormData]: + return AwaitableOrContextManagerWrapper(self._get_form()) + async def close(self) -> None: - if hasattr(self, "_form"): + if self._form is not None: await self._form.close() async def is_disconnected(self) -> bool: diff --git a/tests/test_requests.py b/tests/test_requests.py index 7422ad72a..cbdf478e9 100644 --- a/tests/test_requests.py +++ b/tests/test_requests.py @@ -128,6 +128,19 @@ async def app(scope, receive, send): assert response.json() == {"form": {"abc": "123 @"}} +def test_request_form_context_manager(test_client_factory): + async def app(scope, receive, send): + request = Request(scope, receive) + async with request.form() as form: + response = JSONResponse({"form": dict(form)}) + await response(scope, receive, send) + + client = test_client_factory(app) + + response = client.post("/", data={"abc": "123 @"}) + assert response.json() == {"form": {"abc": "123 @"}} + + def test_request_body_then_stream(test_client_factory): async def app(scope, receive, send): request = Request(scope, receive) From 3c7bdde2bddaeb5dac6c643da4822a72f3bead27 Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Thu, 6 Oct 2022 10:52:07 -0500 Subject: [PATCH 2/8] skip coverage --- starlette/_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/starlette/_utils.py b/starlette/_utils.py index d797c39fb..1114e667e 100644 --- a/starlette/_utils.py +++ b/starlette/_utils.py @@ -6,7 +6,7 @@ if sys.version_info < (3, 8): # pragma: no cover from typing_extensions import Protocol -else: +else: # pragma: no cover from typing import Protocol From 873c1f2322ec1c508264b4cbca634572c3af626e Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Sat, 4 Feb 2023 17:16:31 -0800 Subject: [PATCH 3/8] edit usage in docs --- docs/requests.md | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/docs/requests.md b/docs/requests.md index 747e496d1..08d793a71 100644 --- a/docs/requests.md +++ b/docs/requests.md @@ -81,7 +81,7 @@ There are a few different interfaces for returning the body of the request: The request body as bytes: `await request.body()` -The request body, parsed as form data or multipart: `await request.form()` +The request body, parsed as form data or multipart: `async with request.form() as form:` The request body, parsed as JSON: `await request.json()` @@ -114,7 +114,7 @@ state with `disconnected = await request.is_disconnected()`. Request files are normally sent as multipart form data (`multipart/form-data`). -When you call `await request.form()` you receive a `starlette.datastructures.FormData` which is an immutable +When you call `async with request.form() as form` you receive a `starlette.datastructures.FormData` which is an immutable multidict, containing both file uploads and text input. File upload items are represented as instances of `starlette.datastructures.UploadFile`. `UploadFile` has the following attributes: @@ -137,9 +137,9 @@ As all these methods are `async` methods, you need to "await" them. For example, you can get the file name and the contents with: ```python -form = await request.form() -filename = form["upload_file"].filename -contents = await form["upload_file"].read() +async await request.form() as form: + filename = form["upload_file"].filename + contents = await form["upload_file"].read() ``` !!! info From b0d7ca5dd7528d6ef4b33ad3bf1d15a20a835743 Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Sun, 5 Feb 2023 09:09:46 +0100 Subject: [PATCH 4/8] Update docs/requests.md --- docs/requests.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/requests.md b/docs/requests.md index 08d793a71..1e3cfc46c 100644 --- a/docs/requests.md +++ b/docs/requests.md @@ -137,7 +137,7 @@ As all these methods are `async` methods, you need to "await" them. For example, you can get the file name and the contents with: ```python -async await request.form() as form: +async with request.form() as form: filename = form["upload_file"].filename contents = await form["upload_file"].read() ``` From 42f70182a76abc2db1ac8d7ff210f60d105398bd Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Sun, 5 Feb 2023 21:44:49 -0800 Subject: [PATCH 5/8] Update starlette/_utils.py --- starlette/_utils.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/starlette/_utils.py b/starlette/_utils.py index 1114e667e..a9d7659e6 100644 --- a/starlette/_utils.py +++ b/starlette/_utils.py @@ -22,6 +22,9 @@ def is_async_callable(obj: typing.Any) -> bool: T_co = typing.TypeVar("T_co", covariant=True) +# TODO: once 3.8 is the minimum supported version (27 Jun 2023) +# this can just become +# class AwaitableOrContextManager(typing.Awaitable[T_co], typing.AsyncContextManager[T_co]): pass class AwaitableOrContextManager(Protocol[T_co]): def __await__(self) -> typing.Generator[typing.Any, None, T_co]: ... # pragma: no cover From 39c2d9ce8b9b29846aed0e6605903b680f193823 Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Sun, 5 Feb 2023 21:55:18 -0800 Subject: [PATCH 6/8] fix comment --- starlette/_utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/starlette/_utils.py b/starlette/_utils.py index a9d7659e6..579e05812 100644 --- a/starlette/_utils.py +++ b/starlette/_utils.py @@ -24,7 +24,8 @@ def is_async_callable(obj: typing.Any) -> bool: # TODO: once 3.8 is the minimum supported version (27 Jun 2023) # this can just become -# class AwaitableOrContextManager(typing.Awaitable[T_co], typing.AsyncContextManager[T_co]): pass +# class AwaitableOrContextManager(typing.Awaitable[T_co], typing.AsyncContextManager[T_co], Protocol[T_co]): +# pass class AwaitableOrContextManager(Protocol[T_co]): def __await__(self) -> typing.Generator[typing.Any, None, T_co]: ... # pragma: no cover From a1aba6d28bb124065236de9b607f0f2c0ba3bfeb Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Sun, 5 Feb 2023 21:56:04 -0800 Subject: [PATCH 7/8] fix comment --- starlette/_utils.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/starlette/_utils.py b/starlette/_utils.py index 579e05812..b3ab71eb4 100644 --- a/starlette/_utils.py +++ b/starlette/_utils.py @@ -24,7 +24,11 @@ def is_async_callable(obj: typing.Any) -> bool: # TODO: once 3.8 is the minimum supported version (27 Jun 2023) # this can just become -# class AwaitableOrContextManager(typing.Awaitable[T_co], typing.AsyncContextManager[T_co], Protocol[T_co]): +# class AwaitableOrContextManager( +# typing.Awaitable[T_co], +# typing.AsyncContextManager[T_co], +# typing.Protocol[T_co] +# ): # pass class AwaitableOrContextManager(Protocol[T_co]): def __await__(self) -> typing.Generator[typing.Any, None, T_co]: From bf3d7f056678cfb8bbd859fc80f67db997aed6fa Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Sun, 5 Feb 2023 21:56:18 -0800 Subject: [PATCH 8/8] trailing comma --- starlette/_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/starlette/_utils.py b/starlette/_utils.py index b3ab71eb4..d781647ff 100644 --- a/starlette/_utils.py +++ b/starlette/_utils.py @@ -27,7 +27,7 @@ def is_async_callable(obj: typing.Any) -> bool: # class AwaitableOrContextManager( # typing.Awaitable[T_co], # typing.AsyncContextManager[T_co], -# typing.Protocol[T_co] +# typing.Protocol[T_co], # ): # pass class AwaitableOrContextManager(Protocol[T_co]):