Skip to content

Commit 82da8ee

Browse files
committed
Merge branch 'startup-process-twice' of https://github.com/sambarza/cheshire-cat-ai-core into sambarza-startup-process-twice
2 parents ca0367c + 7e8792f commit 82da8ee

File tree

3 files changed

+137
-127
lines changed

3 files changed

+137
-127
lines changed

core/cat/main.py

+5-126
Original file line numberDiff line numberDiff line change
@@ -1,134 +1,13 @@
11
import uvicorn
2-
import asyncio
3-
from contextlib import asynccontextmanager
4-
from scalar_fastapi import get_scalar_api_reference
52

6-
from fastapi import FastAPI
7-
from fastapi.routing import APIRoute
8-
from fastapi.responses import JSONResponse
9-
from fastapi.exceptions import RequestValidationError
10-
from fastapi.middleware.cors import CORSMiddleware
11-
12-
from cat.log import log
133
from cat.env import get_env, fix_legacy_env_variables
14-
from cat.routes import (
15-
base,
16-
auth,
17-
users,
18-
settings,
19-
llm,
20-
embedder,
21-
auth_handler,
22-
plugins,
23-
upload,
24-
websocket,
25-
)
26-
from cat.routes.memory.memory_router import memory_router
27-
from cat.routes.static import admin, static
28-
from cat.routes.openapi import get_openapi_configuration_function
29-
from cat.looking_glass.cheshire_cat import CheshireCat
30-
31-
32-
# TODO: take away in v2
33-
fix_legacy_env_variables()
34-
35-
36-
@asynccontextmanager
37-
async def lifespan(app: FastAPI):
38-
# ^._.^
39-
#
40-
# loads Cat and plugins
41-
# Every endpoint can access the cat instance via request.app.state.ccat
42-
# - Not using midlleware because I can't make it work with both http and websocket;
43-
# - Not using Depends because it only supports callables (not instances)
44-
# - Starlette allows this: https://www.starlette.io/applications/#storing-state-on-the-app-instance
45-
app.state.ccat = CheshireCat()
46-
47-
# Dict of pseudo-sessions (key is the user_id)
48-
app.state.strays = {}
49-
50-
# set a reference to asyncio event loop
51-
app.state.event_loop = asyncio.get_running_loop()
52-
53-
# startup message with admin, public and swagger addresses
54-
log.welcome()
55-
56-
yield
57-
58-
59-
def custom_generate_unique_id(route: APIRoute):
60-
return f"{route.name}"
61-
62-
63-
# REST API
64-
cheshire_cat_api = FastAPI(
65-
lifespan=lifespan, generate_unique_id_function=custom_generate_unique_id,
66-
docs_url=None, redoc_url=None, title="Cheshire-Cat API",
67-
license_info={"name": "GPL-3", "url": "https://www.gnu.org/licenses/gpl-3.0.en.html"},
68-
)
69-
70-
# Configures the CORS middleware for the FastAPI app
71-
cors_allowed_origins_str = get_env("CCAT_CORS_ALLOWED_ORIGINS")
72-
origins = cors_allowed_origins_str.split(",") if cors_allowed_origins_str else ["*"]
73-
cheshire_cat_api.add_middleware(
74-
CORSMiddleware,
75-
allow_origins=origins,
76-
allow_credentials=True,
77-
allow_methods=["*"],
78-
allow_headers=["*"],
79-
)
80-
81-
# Add routers to the middleware stack.
82-
cheshire_cat_api.include_router(base.router, tags=["Home"])
83-
cheshire_cat_api.include_router(auth.router, tags=["User Auth"], prefix="/auth")
84-
cheshire_cat_api.include_router(users.router, tags=["Users"], prefix="/users")
85-
cheshire_cat_api.include_router(settings.router, tags=["Settings"], prefix="/settings")
86-
cheshire_cat_api.include_router(
87-
llm.router, tags=["Large Language Model"], prefix="/llm"
88-
)
89-
cheshire_cat_api.include_router(embedder.router, tags=["Embedder"], prefix="/embedder")
90-
cheshire_cat_api.include_router(plugins.router, tags=["Plugins"], prefix="/plugins")
91-
cheshire_cat_api.include_router(memory_router, prefix="/memory")
92-
cheshire_cat_api.include_router(
93-
upload.router, tags=["Rabbit Hole"], prefix="/rabbithole"
94-
)
95-
cheshire_cat_api.include_router(
96-
auth_handler.router, tags=["AuthHandler"], prefix="/auth_handler"
97-
)
98-
cheshire_cat_api.include_router(websocket.router, tags=["Websocket"])
99-
100-
# mount static files
101-
# this cannot be done via fastapi.APIrouter:
102-
# https://github.com/tiangolo/fastapi/discussions/9070
103-
104-
# admin single page app (static build)
105-
admin.mount(cheshire_cat_api)
106-
# static files (for plugins and other purposes)
107-
static.mount(cheshire_cat_api)
108-
109-
110-
# error handling
111-
@cheshire_cat_api.exception_handler(RequestValidationError)
112-
async def validation_exception_handler(request, exc):
113-
return JSONResponse(
114-
status_code=400,
115-
content={"error": exc.errors()},
116-
)
117-
118-
119-
# openapi customization
120-
cheshire_cat_api.openapi = get_openapi_configuration_function(cheshire_cat_api)
121-
122-
@cheshire_cat_api.get("/docs", include_in_schema=False)
123-
async def scalar_docs():
124-
return get_scalar_api_reference(
125-
openapi_url=cheshire_cat_api.openapi_url,
126-
title=cheshire_cat_api.title,
127-
scalar_favicon_url="https://cheshirecat.ai/wp-content/uploads/2023/10/Logo-Cheshire-Cat.svg",
128-
)
1294

1305
# RUN!
1316
if __name__ == "__main__":
7+
8+
# TODO: take away in v2
9+
fix_legacy_env_variables()
10+
13211
# debugging utilities, to deactivate put `DEBUG=false` in .env
13312
debug_config = {}
13413
if get_env("CCAT_DEBUG") == "true":
@@ -146,7 +25,7 @@ async def scalar_docs():
14625
}
14726

14827
uvicorn.run(
149-
"cat.main:cheshire_cat_api",
28+
"cat.startup:cheshire_cat_api",
15029
host="0.0.0.0",
15130
port=80,
15231
use_colors=True,

core/cat/startup.py

+131
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
import asyncio
2+
from contextlib import asynccontextmanager
3+
from scalar_fastapi import get_scalar_api_reference
4+
5+
from fastapi import FastAPI
6+
from fastapi.routing import APIRoute
7+
from fastapi.responses import JSONResponse
8+
from fastapi.exceptions import RequestValidationError
9+
from fastapi.middleware.cors import CORSMiddleware
10+
11+
from cat.log import log
12+
from cat.env import get_env
13+
from cat.routes import (
14+
base,
15+
auth,
16+
users,
17+
settings,
18+
llm,
19+
embedder,
20+
auth_handler,
21+
plugins,
22+
upload,
23+
websocket,
24+
)
25+
from cat.routes.memory.memory_router import memory_router
26+
from cat.routes.static import admin, static
27+
from cat.routes.openapi import get_openapi_configuration_function
28+
from cat.looking_glass.cheshire_cat import CheshireCat
29+
30+
31+
@asynccontextmanager
32+
async def lifespan(app: FastAPI):
33+
34+
# ^._.^
35+
#
36+
# loads Cat and plugins
37+
# Every endpoint can access the cat instance via request.app.state.ccat
38+
# - Not using midlleware because I can't make it work with both http and websocket;
39+
# - Not using Depends because it only supports callables (not instances)
40+
# - Starlette allows this: https://www.starlette.io/applications/#storing-state-on-the-app-instance
41+
app.state.ccat = CheshireCat()
42+
43+
# Dict of pseudo-sessions (key is the user_id)
44+
app.state.strays = {}
45+
46+
# set a reference to asyncio event loop
47+
app.state.event_loop = asyncio.get_running_loop()
48+
49+
# startup message with admin, public and swagger addresses
50+
log.welcome()
51+
52+
yield
53+
54+
55+
def custom_generate_unique_id(route: APIRoute):
56+
return f"{route.name}"
57+
58+
59+
# REST API
60+
cheshire_cat_api = FastAPI(
61+
lifespan=lifespan,
62+
generate_unique_id_function=custom_generate_unique_id,
63+
docs_url=None,
64+
redoc_url=None,
65+
title="Cheshire-Cat API",
66+
license_info={
67+
"name": "GPL-3",
68+
"url": "https://www.gnu.org/licenses/gpl-3.0.en.html",
69+
},
70+
)
71+
72+
# Configures the CORS middleware for the FastAPI app
73+
cors_allowed_origins_str = get_env("CCAT_CORS_ALLOWED_ORIGINS")
74+
origins = cors_allowed_origins_str.split(",") if cors_allowed_origins_str else ["*"]
75+
cheshire_cat_api.add_middleware(
76+
CORSMiddleware,
77+
allow_origins=origins,
78+
allow_credentials=True,
79+
allow_methods=["*"],
80+
allow_headers=["*"],
81+
)
82+
83+
# Add routers to the middleware stack.
84+
cheshire_cat_api.include_router(base.router, tags=["Home"])
85+
cheshire_cat_api.include_router(auth.router, tags=["User Auth"], prefix="/auth")
86+
cheshire_cat_api.include_router(users.router, tags=["Users"], prefix="/users")
87+
cheshire_cat_api.include_router(settings.router, tags=["Settings"], prefix="/settings")
88+
cheshire_cat_api.include_router(
89+
llm.router, tags=["Large Language Model"], prefix="/llm"
90+
)
91+
cheshire_cat_api.include_router(embedder.router, tags=["Embedder"], prefix="/embedder")
92+
cheshire_cat_api.include_router(plugins.router, tags=["Plugins"], prefix="/plugins")
93+
cheshire_cat_api.include_router(memory_router, prefix="/memory")
94+
cheshire_cat_api.include_router(
95+
upload.router, tags=["Rabbit Hole"], prefix="/rabbithole"
96+
)
97+
cheshire_cat_api.include_router(
98+
auth_handler.router, tags=["AuthHandler"], prefix="/auth_handler"
99+
)
100+
cheshire_cat_api.include_router(websocket.router, tags=["Websocket"])
101+
102+
# mount static files
103+
# this cannot be done via fastapi.APIrouter:
104+
# https://github.com/tiangolo/fastapi/discussions/9070
105+
106+
# admin single page app (static build)
107+
admin.mount(cheshire_cat_api)
108+
# static files (for plugins and other purposes)
109+
static.mount(cheshire_cat_api)
110+
111+
112+
# error handling
113+
@cheshire_cat_api.exception_handler(RequestValidationError)
114+
async def validation_exception_handler(request, exc):
115+
return JSONResponse(
116+
status_code=400,
117+
content={"error": exc.errors()},
118+
)
119+
120+
121+
# openapi customization
122+
cheshire_cat_api.openapi = get_openapi_configuration_function(cheshire_cat_api)
123+
124+
125+
@cheshire_cat_api.get("/docs", include_in_schema=False)
126+
async def scalar_docs():
127+
return get_scalar_api_reference(
128+
openapi_url=cheshire_cat_api.openapi_url,
129+
title=cheshire_cat_api.title,
130+
scalar_favicon_url="https://cheshirecat.ai/wp-content/uploads/2023/10/Logo-Cheshire-Cat.svg",
131+
)

core/tests/conftest.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import cat.utils as utils
1919
from cat.memory.vector_memory import VectorMemory
2020
from cat.mad_hatter.plugin import Plugin
21-
from cat.main import cheshire_cat_api
21+
from cat.startup import cheshire_cat_api
2222
from tests.utils import create_mock_plugin_zip
2323

2424
import time

0 commit comments

Comments
 (0)