Skip to content

Commit f96f45c

Browse files
committed
Replace BaseHTTPMiddleware with much faster ASGI equalivalent
1 parent 068e108 commit f96f45c

8 files changed

+71
-60
lines changed

requirements.txt

+1
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ python-multipart==0.0.5
2020
rdkit==2022.3.5
2121
requests==2.28.1
2222
SQLAlchemy==1.4.45
23+
starlette-context==0.3.5
2324
urllib3==1.26.9
2425
uvicorn[standard]==0.20.0
2526
yamlreader==3.0.4

tdp_core/dbmanager.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from . import manager
99
from .dbview import DBConnector
1010
from .middleware.close_web_sessions_middleware import CloseWebSessionsMiddleware
11-
from .middleware.request_context_middleware import get_request
11+
from .middleware.request_context_plugin import get_request
1212

1313
_log = logging.getLogger(__name__)
1414

@@ -93,11 +93,14 @@ def create_web_session(self, engine_or_id: Union[Engine, str]) -> Session:
9393
"""
9494
session = self.create_session(engine_or_id)
9595

96+
r = get_request()
97+
if not r:
98+
raise Exception("No request found, did you use a create_web_sesssion outside of a request?")
9699
try:
97-
existing_sessions = get_request().state.db_sessions
100+
existing_sessions = r.state.db_sessions
98101
except (KeyError, AttributeError):
99102
existing_sessions = []
100-
get_request().state.db_sessions = existing_sessions
103+
r.state.db_sessions = existing_sessions
101104
existing_sessions.append(session)
102105

103106
return session
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,23 @@
1-
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
2-
from starlette.requests import Request
1+
from fastapi import FastAPI
32

3+
from .request_context_plugin import get_request
44

5-
class CloseWebSessionsMiddleware(BaseHTTPMiddleware):
6-
async def dispatch(self, request: Request, call_next: RequestResponseEndpoint):
7-
response = await call_next(request)
85

9-
try:
10-
for db_session in request.state.db_sessions:
11-
try:
12-
db_session.close()
13-
except Exception:
14-
pass
15-
except (KeyError, AttributeError):
16-
pass
6+
# Use basic ASGI middleware instead of BaseHTTPMiddleware as it is significantly faster: https://github.com/tiangolo/fastapi/issues/2696#issuecomment-768224643
7+
class CloseWebSessionsMiddleware:
8+
def __init__(self, app: FastAPI):
9+
self.app = app
1710

18-
return response
11+
async def __call__(self, scope, receive, send):
12+
await self.app(scope, receive, send)
13+
14+
r = get_request()
15+
if r:
16+
try:
17+
for db_session in r.state.db_sessions:
18+
try:
19+
db_session.close()
20+
except Exception:
21+
pass
22+
except (KeyError, AttributeError):
23+
pass
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,23 @@
11
import logging
22

3-
from fastapi import HTTPException
3+
from fastapi import FastAPI, HTTPException
44
from fastapi.exception_handlers import http_exception_handler
5-
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
6-
from starlette.requests import Request
75

86
from ..server.utils import detail_from_exception
7+
from .request_context_plugin import get_request
98

109

11-
class ExceptionHandlerMiddleware(BaseHTTPMiddleware):
12-
async def dispatch(self, request: Request, call_next: RequestResponseEndpoint):
10+
# Use basic ASGI middleware instead of BaseHTTPMiddleware as it is significantly faster: https://github.com/tiangolo/fastapi/issues/2696#issuecomment-768224643
11+
class ExceptionHandlerMiddleware:
12+
def __init__(self, app: FastAPI):
13+
self.app = app
14+
15+
async def __call__(self, scope, receive, send):
1316
try:
14-
return await call_next(request)
17+
await self.app(scope, receive, send)
1518
except Exception as e:
1619
logging.exception("An error occurred in FastAPI")
1720
return await http_exception_handler(
18-
request,
21+
get_request(), # type: ignore
1922
e if isinstance(e, HTTPException) else HTTPException(status_code=500, detail=detail_from_exception(e)),
2023
)

tdp_core/middleware/request_context_middleware.py

-21
This file was deleted.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
from typing import Optional
2+
3+
from starlette.requests import HTTPConnection, Request
4+
from starlette_context import context
5+
from starlette_context.plugins.base import Plugin
6+
7+
8+
def get_request() -> Request | None:
9+
return context.get("request")
10+
11+
12+
class RequestContextPlugin(Plugin):
13+
# The returned value will be inserted in the context with this key
14+
key = "request"
15+
16+
async def process_request(self, request: Request | HTTPConnection) -> Optional[Request | HTTPConnection]:
17+
return request

tdp_core/security/manager.py

+13-12
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from fastapi.security.utils import get_authorization_scheme_param
1010

1111
from .. import manager
12-
from ..middleware.request_context_middleware import get_request
12+
from ..middleware.request_context_plugin import get_request
1313
from .model import ANONYMOUS_USER, LogoutReturnValue, User
1414
from .store.base_store import BaseStore
1515

@@ -119,17 +119,18 @@ def _delegate_stores_until_not_none(self, store_method_name: str, *args):
119119
@property
120120
def current_user(self) -> Optional[User]:
121121
try:
122-
req = get_request()
123-
# Fetch the existing user from the request if there is any
124-
try:
125-
user = req.state.user
126-
if user:
127-
return user
128-
except (KeyError, AttributeError):
129-
pass
130-
# If there is no user, try to load it from the request and store it in the request
131-
user = req.state.user = self.load_from_request(get_request())
132-
return user
122+
r = get_request()
123+
if r:
124+
# Fetch the existing user from the request if there is any
125+
try:
126+
user = r.state.user
127+
if user:
128+
return user
129+
except (KeyError, AttributeError):
130+
pass
131+
# If there is no user, try to load it from the request and store it in the request
132+
user = r.state.user = self.load_from_request(r)
133+
return user
133134
except HTTPException:
134135
return None
135136
except Exception:

tdp_core/server/visyn_server.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from fastapi.middleware.wsgi import WSGIMiddleware
99
from pydantic import create_model
1010
from pydantic.utils import deep_update
11+
from starlette_context.middleware import RawContextMiddleware
1112

1213
from ..settings.constants import default_logging_dict
1314

@@ -60,7 +61,6 @@ def create_visyn_server(
6061
)
6162

6263
from ..middleware.exception_handler_middleware import ExceptionHandlerMiddleware
63-
from ..middleware.request_context_middleware import RequestContextMiddleware
6464

6565
# TODO: For some reason, a @app.exception_handler(Exception) is not called here. We use a middleware instead.
6666
app.add_middleware(ExceptionHandlerMiddleware)
@@ -143,8 +143,10 @@ def create_visyn_server(
143143
for p in plugins:
144144
p.plugin.init_app(app)
145145

146-
# Add middleware to access Request "outside"
147-
app.add_middleware(RequestContextMiddleware)
146+
from ..middleware.request_context_plugin import RequestContextPlugin
147+
148+
# Use starlette-context to store the current request globally, i.e. accessible via context['request']
149+
app.add_middleware(RawContextMiddleware, plugins=(RequestContextPlugin(),))
148150

149151
# TODO: Move up?
150152
app.add_api_route("/health", health) # type: ignore

0 commit comments

Comments
 (0)