Batch upload of documents #881

Merged · 2 commits · Aug 19, 2024
117 changes: 116 additions & 1 deletion core/cat/routes/upload.py
@@ -2,7 +2,7 @@
import requests
import io
import json
from typing import Dict
from typing import Dict, List
from copy import deepcopy

from pydantic import BaseModel, Field, ConfigDict
@@ -36,6 +36,7 @@ def format_upload_file(upload_file: UploadFile) -> UploadFile:
    return UploadFile(filename=upload_file.filename, file=io.BytesIO(file_content))



# receive files via http endpoint
@router.post("/")
async def upload_file(
@@ -129,6 +130,120 @@ async def upload_file(
"info": "File is being ingested asynchronously",
}


# receive multiple files via http endpoint
@router.post("/batch")
async def upload_files(
    request: Request,
    files: List[UploadFile],
    background_tasks: BackgroundTasks,
    chunk_size: int | None = Form(
        default=None,
        description="Maximum length of each chunk after the document is split (in tokens)"
    ),
    chunk_overlap: int | None = Form(
        default=None,
        description="Chunk overlap (in tokens)"
    ),
    metadata: str = Form(
        default="{}",
        description="Metadata to be stored, where each key is the name of a file being uploaded and the corresponding value is another dictionary with metadata specific to that file. "
        "Since this is passed alongside form data, metadata must be a JSON string (use `json.dumps(metadata)`)."
    ),
    stray=Depends(HTTPAuth(AuthResource.UPLOAD, AuthPermission.WRITE)),
) -> Dict:
"""Batch upload multiple files containing text (.txt, .md, .pdf, etc.). File content will be extracted and segmented into chunks.
Chunks will be then vectorized and stored into documents memory.

Note
----------
`chunk_size`, `chunk_overlap` anad `metadata` must be passed as form data.
This is necessary because the HTTP protocol does not allow file uploads to be sent as JSON.

Example
----------
```
files = []
files_to_upload = {"sample.pdf":"application/pdf","sample.txt":"application/txt"}

for file_name in files_to_upload:
content_type = files_to_upload[file_name]
file_path = f"tests/mocks/{file_name}"
files.append( ("files", ((file_name, open(file_path, "rb"), content_type))) )


metadata = {
"sample.pdf":{
"source": "sample.pdf",
"title": "Test title",
"author": "Test author",
"year": 2020
},
"sample.txt":{
"source": "sample.txt",
"title": "Test title",
"author": "Test author",
"year": 2021
}
}

# upload file endpoint only accepts form-encoded data
payload = {
"chunk_size": 128,
"metadata": json.dumps(metadata)
}

response = requests.post(
"http://localhost:1865/rabbithole/batch",
files=files,
data=payload
)
```
"""

    # Check that the file formats are supported
    admitted_types = stray.rabbit_hole.file_handlers.keys()
    log.info(f"Uploading {len(files)} files")

    response = {}
    metadata_dict = json.loads(metadata)

    for file in files:
        # Get file MIME type
        content_type = mimetypes.guess_type(file.filename)[0]
        log.info(f"Uploaded {file.filename} {content_type} down the rabbit hole")

        # check if the MIME type of the uploaded file is supported
        if content_type not in admitted_types:
            raise HTTPException(
                status_code=400,
                detail={
                    "error": f'MIME type {content_type} not supported. Admitted types: {" - ".join(admitted_types)}'
                },
            )

        # upload file to long term memory, in the background
        background_tasks.add_task(
            # we deepcopy the file because FastAPI does not keep the file in memory after the response returns to the client
            # https://github.com/tiangolo/fastapi/discussions/10936
            stray.rabbit_hole.ingest_file,
            stray,
            deepcopy(format_upload_file(file)),
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            # pass the file's metadata if present, otherwise an empty dictionary
            metadata=metadata_dict.get(file.filename, {})
        )

        # reply to client
        response[file.filename] = {
            "filename": file.filename,
            "content_type": file.content_type,
            "info": "File is being ingested asynchronously",
        }

    return response

# This model can be used only for the upload_url endpoint,
# because in upload_file we need to pass the file and config as form data
class UploadURLConfig(BaseModel):
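For reviewers trying the new endpoint locally, a minimal client sketch (not part of the diff) is below. It assumes a Cat instance listening on http://localhost:1865, as in the docstring example, and the mock files used by the tests; adjust the URL and paths for your setup.

```
import json
import requests

url = "http://localhost:1865/rabbithole/batch"

files_to_upload = {"sample.pdf": "application/pdf", "sample.txt": "application/txt"}
files = [
    ("files", (name, open(f"tests/mocks/{name}", "rb"), content_type))
    for name, content_type in files_to_upload.items()
]

# chunk_size and metadata must go through form data, not JSON
payload = {
    "chunk_size": 128,
    "metadata": json.dumps({name: {"source": name} for name in files_to_upload}),
}

response = requests.post(url, files=files, data=payload)

if response.status_code == 400:
    # the endpoint returns 400 if any file has an unsupported MIME type
    print(response.json()["detail"]["error"])
else:
    response.raise_for_status()
    # one entry per file, keyed by filename
    for filename, entry in response.json().items():
        print(f"{filename}: {entry['info']}")
```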
94 changes: 94 additions & 0 deletions core/tests/routes/rabbithole/test_upload_docs.py
@@ -49,6 +49,55 @@ def test_rabbithole_upload_pdf(client):
    assert len(declarative_memories) == 4


def test_rabbithole_upload_batch_one_file(client):
    content_type = "application/pdf"
    file_name = "sample.pdf"
    file_path = f"tests/mocks/{file_name}"
    with open(file_path, "rb") as f:
        files = [("files", (file_name, f, content_type))]

        response = client.post("/rabbithole/batch", files=files)

    # check response
    assert response.status_code == 200
    json = response.json()
    assert len(json) == 1
    assert file_name in json
    assert json[file_name]["filename"] == file_name
    assert json[file_name]["content_type"] == content_type
    assert "File is being ingested" in json[file_name]["info"]

    # check memory contents
    # declarative memory contains the chunks extracted from the file
    declarative_memories = get_declarative_memory_contents(client)
    assert len(declarative_memories) == 4

def test_rabbithole_upload_batch_multiple_files(client):
    files = []
    files_to_upload = {"sample.pdf": "application/pdf", "sample.txt": "application/txt"}
    for file_name, content_type in files_to_upload.items():
        file_path = f"tests/mocks/{file_name}"
        files.append(("files", (file_name, open(file_path, "rb"), content_type)))

    response = client.post("/rabbithole/batch", files=files)

    # check response
    assert response.status_code == 200
    json = response.json()
    assert len(json) == len(files_to_upload)
    for file_name in files_to_upload:
        assert file_name in json
        assert json[file_name]["filename"] == file_name
        assert json[file_name]["content_type"] == files_to_upload[file_name]
        assert "File is being ingested" in json[file_name]["info"]

    # check memory contents
    # declarative memory contains the chunks extracted from both files
    declarative_memories = get_declarative_memory_contents(client)
    assert len(declarative_memories) == 7
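    # 7 memories = 4 chunks from sample.pdf (same count asserted in the
    # single-file test above) plus, presumably, 3 chunks from sample.txt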


def test_rabbihole_chunking(client):
    content_type = "application/pdf"
    file_name = "sample.pdf"
@@ -101,4 +150,49 @@ def test_rabbithole_upload_doc_with_metadata(client):
assert "when" in dm["metadata"]
assert "source" in dm["metadata"]
print(dm["metadata"])
assert dm["metadata"][k] == v


def test_rabbithole_upload_docs_batch_with_metadata(client):
    files = []
    files_to_upload = {"sample.pdf": "application/pdf", "sample.txt": "application/txt"}
    for file_name, content_type in files_to_upload.items():
        file_path = f"tests/mocks/{file_name}"
        files.append(("files", (file_name, open(file_path, "rb"), content_type)))

    metadata = {
        "sample.pdf": {
            "source": "sample.pdf",
            "title": "Test title",
            "author": "Test author",
            "year": 2020
        },
        "sample.txt": {
            "source": "sample.txt",
            "title": "Test title",
            "author": "Test author",
            "year": 2021
        }
    }

    # the upload endpoint only accepts form-encoded data
    payload = {
        "metadata": json.dumps(metadata)
    }

    response = client.post("/rabbithole/batch", files=files, data=payload)

    # check response
    assert response.status_code == 200

    # check memory contents
    declarative_memories = get_declarative_memory_contents(client)
    assert len(declarative_memories) == 7
    for dm in declarative_memories:
        assert "when" in dm["metadata"]
        assert "source" in dm["metadata"]
        print(dm["metadata"])
        # compare with the metadata of the corresponding file
        for k, v in metadata[dm["metadata"]["source"]].items():
            assert dm["metadata"][k] == v