-
Notifications
You must be signed in to change notification settings - Fork 1.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: Describe pictures using vision models (#259)
* draft for picture description models Signed-off-by: Michele Dolfi <[email protected]> * vlm description using AutoModelForVision2Seq Signed-off-by: Michele Dolfi <[email protected]> * add generation options Signed-off-by: Michele Dolfi <[email protected]> * update vlm API Signed-off-by: Michele Dolfi <[email protected]> * allow only localhost traffic Signed-off-by: Michele Dolfi <[email protected]> * rename model Signed-off-by: Michele Dolfi <[email protected]> * do not run with vlm api Signed-off-by: Michele Dolfi <[email protected]> * more renaming Signed-off-by: Michele Dolfi <[email protected]> * fix examples path Signed-off-by: Michele Dolfi <[email protected]> * apply CLI download login Signed-off-by: Michele Dolfi <[email protected]> * fix name of cli argument Signed-off-by: Michele Dolfi <[email protected]> * use with_smolvlm in models download Signed-off-by: Michele Dolfi <[email protected]> --------- Signed-off-by: Michele Dolfi <[email protected]>
- Loading branch information
1 parent
fba3cf9
commit 4cc6e3e
Showing
14 changed files
with
508 additions
and
11 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,105 @@ | ||
import base64 | ||
import io | ||
import logging | ||
from typing import Iterable, List, Optional | ||
|
||
import httpx | ||
from docling_core.types.doc import PictureItem | ||
from docling_core.types.doc.document import ( # TODO: move import to docling_core.types.doc | ||
PictureDescriptionData, | ||
) | ||
from PIL import Image | ||
from pydantic import BaseModel, ConfigDict | ||
|
||
from docling.datamodel.pipeline_options import PictureDescriptionApiOptions | ||
from docling.models.picture_description_base_model import PictureDescriptionBaseModel | ||
|
||
_log = logging.getLogger(__name__) | ||
|
||
|
||
class ChatMessage(BaseModel): | ||
role: str | ||
content: str | ||
|
||
|
||
class ResponseChoice(BaseModel): | ||
index: int | ||
message: ChatMessage | ||
finish_reason: str | ||
|
||
|
||
class ResponseUsage(BaseModel): | ||
prompt_tokens: int | ||
completion_tokens: int | ||
total_tokens: int | ||
|
||
|
||
class ApiResponse(BaseModel): | ||
model_config = ConfigDict( | ||
protected_namespaces=(), | ||
) | ||
|
||
id: str | ||
model: Optional[str] = None # returned by openai | ||
choices: List[ResponseChoice] | ||
created: int | ||
usage: ResponseUsage | ||
|
||
|
||
class PictureDescriptionApiModel(PictureDescriptionBaseModel): | ||
# elements_batch_size = 4 | ||
|
||
def __init__(self, enabled: bool, options: PictureDescriptionApiOptions): | ||
super().__init__(enabled=enabled, options=options) | ||
self.options: PictureDescriptionApiOptions | ||
|
||
if self.enabled: | ||
if options.url.host != "localhost": | ||
raise NotImplementedError( | ||
"The options try to connect to remote APIs which are not yet allowed." | ||
) | ||
|
||
def _annotate_images(self, images: Iterable[Image.Image]) -> Iterable[str]: | ||
# Note: technically we could make a batch request here, | ||
# but not all APIs will allow for it. For example, vllm won't allow more than 1. | ||
for image in images: | ||
img_io = io.BytesIO() | ||
image.save(img_io, "PNG") | ||
image_base64 = base64.b64encode(img_io.getvalue()).decode("utf-8") | ||
|
||
messages = [ | ||
{ | ||
"role": "user", | ||
"content": [ | ||
{ | ||
"type": "text", | ||
"text": self.options.prompt, | ||
}, | ||
{ | ||
"type": "image_url", | ||
"image_url": { | ||
"url": f"data:image/png;base64,{image_base64}" | ||
}, | ||
}, | ||
], | ||
} | ||
] | ||
|
||
payload = { | ||
"messages": messages, | ||
**self.options.params, | ||
} | ||
|
||
r = httpx.post( | ||
str(self.options.url), | ||
headers=self.options.headers, | ||
json=payload, | ||
timeout=self.options.timeout, | ||
) | ||
if not r.is_success: | ||
_log.error(f"Error calling the API. Reponse was {r.text}") | ||
r.raise_for_status() | ||
|
||
api_resp = ApiResponse.model_validate_json(r.text) | ||
generated_text = api_resp.choices[0].message.content.strip() | ||
yield generated_text |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
import logging | ||
from pathlib import Path | ||
from typing import Any, Iterable, List, Optional, Union | ||
|
||
from docling_core.types.doc import ( | ||
DoclingDocument, | ||
NodeItem, | ||
PictureClassificationClass, | ||
PictureItem, | ||
) | ||
from docling_core.types.doc.document import ( # TODO: move import to docling_core.types.doc | ||
PictureDescriptionData, | ||
) | ||
from PIL import Image | ||
|
||
from docling.datamodel.pipeline_options import PictureDescriptionBaseOptions | ||
from docling.models.base_model import ( | ||
BaseItemAndImageEnrichmentModel, | ||
ItemAndImageEnrichmentElement, | ||
) | ||
|
||
|
||
class PictureDescriptionBaseModel(BaseItemAndImageEnrichmentModel): | ||
images_scale: float = 2.0 | ||
|
||
def __init__( | ||
self, | ||
enabled: bool, | ||
options: PictureDescriptionBaseOptions, | ||
): | ||
self.enabled = enabled | ||
self.options = options | ||
self.provenance = "not-implemented" | ||
|
||
def is_processable(self, doc: DoclingDocument, element: NodeItem) -> bool: | ||
return self.enabled and isinstance(element, PictureItem) | ||
|
||
def _annotate_images(self, images: Iterable[Image.Image]) -> Iterable[str]: | ||
raise NotImplementedError | ||
|
||
def __call__( | ||
self, | ||
doc: DoclingDocument, | ||
element_batch: Iterable[ItemAndImageEnrichmentElement], | ||
) -> Iterable[NodeItem]: | ||
if not self.enabled: | ||
for element in element_batch: | ||
yield element.item | ||
return | ||
|
||
images: List[Image.Image] = [] | ||
elements: List[PictureItem] = [] | ||
for el in element_batch: | ||
assert isinstance(el.item, PictureItem) | ||
elements.append(el.item) | ||
images.append(el.image) | ||
|
||
outputs = self._annotate_images(images) | ||
|
||
for item, output in zip(elements, outputs): | ||
item.annotations.append( | ||
PictureDescriptionData(text=output, provenance=self.provenance) | ||
) | ||
yield item |
Oops, something went wrong.