feat: Describe pictures using vision models (#259)
* draft for picture description models

Signed-off-by: Michele Dolfi <[email protected]>

* vlm description using AutoModelForVision2Seq

Signed-off-by: Michele Dolfi <[email protected]>

* add generation options

Signed-off-by: Michele Dolfi <[email protected]>

* update vlm API

Signed-off-by: Michele Dolfi <[email protected]>

* allow only localhost traffic

Signed-off-by: Michele Dolfi <[email protected]>

* rename model

Signed-off-by: Michele Dolfi <[email protected]>

* do not run with vlm api

Signed-off-by: Michele Dolfi <[email protected]>

* more renaming

Signed-off-by: Michele Dolfi <[email protected]>

* fix examples path

Signed-off-by: Michele Dolfi <[email protected]>

* apply CLI download logic

Signed-off-by: Michele Dolfi <[email protected]>

* fix name of cli argument

Signed-off-by: Michele Dolfi <[email protected]>

* use with_smolvlm in models download

Signed-off-by: Michele Dolfi <[email protected]>

---------

Signed-off-by: Michele Dolfi <[email protected]>
dolfim-ibm authored Feb 7, 2025
1 parent fba3cf9 commit 4cc6e3e
Showing 14 changed files with 508 additions and 11 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/checks.yml
@@ -28,7 +28,7 @@ jobs:
run: |
for file in docs/examples/*.py; do
# Skip batch_convert.py
if [[ "$(basename "$file")" =~ ^(batch_convert|minimal|export_multimodal|custom_convert|develop_picture_enrichment|rapidocr_with_custom_models|offline_convert).py ]]; then
if [[ "$(basename "$file")" =~ ^(batch_convert|minimal|export_multimodal|custom_convert|develop_picture_enrichment|rapidocr_with_custom_models|offline_convert|pictures_description|pictures_description_api).py ]]; then
echo "Skipping $file"
continue
fi
5 changes: 5 additions & 0 deletions docling/cli/main.py
@@ -226,6 +226,10 @@ def convert(
help="Enable the picture classification enrichment model in the pipeline.",
),
] = False,
enrich_picture_description: Annotated[
bool,
typer.Option(..., help="Enable the picture description model in the pipeline."),
] = False,
artifacts_path: Annotated[
Optional[Path],
typer.Option(..., help="If provided, the location of the model artifacts."),
@@ -382,6 +386,7 @@ def convert(
do_table_structure=True,
do_code_enrichment=enrich_code,
do_formula_enrichment=enrich_formula,
do_picture_description=enrich_picture_description,
do_picture_classification=enrich_picture_classes,
document_timeout=document_timeout,
)
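
The new CLI flag just flips do_picture_description on PdfPipelineOptions. A minimal sketch of the equivalent programmatic setup, assuming the usual DocumentConverter wiring (the input file name is a placeholder, not part of this diff):

from docling.datamodel.base_models import InputFormat
from docling.datamodel.pipeline_options import PdfPipelineOptions
from docling.document_converter import DocumentConverter, PdfFormatOption

# Mirror what the CLI does when the picture description enrichment is enabled.
pipeline_options = PdfPipelineOptions()
pipeline_options.do_picture_description = True

converter = DocumentConverter(
    format_options={InputFormat.PDF: PdfFormatOption(pipeline_options=pipeline_options)}
)
result = converter.convert("example.pdf")  # placeholder input
# Each described PictureItem will carry PictureDescriptionData entries in its
# .annotations list (see picture_description_base_model.py below).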
2 changes: 2 additions & 0 deletions docling/cli/models.py
@@ -31,6 +31,7 @@ class _AvailableModels(str, Enum):
TABLEFORMER = "tableformer"
CODE_FORMULA = "code_formula"
PICTURE_CLASSIFIER = "picture_classifier"
SMOLVLM = "smolvlm"
EASYOCR = "easyocr"


@@ -81,6 +82,7 @@ def download(
with_tableformer=_AvailableModels.TABLEFORMER in to_download,
with_code_formula=_AvailableModels.CODE_FORMULA in to_download,
with_picture_classifier=_AvailableModels.PICTURE_CLASSIFIER in to_download,
with_smolvlm=_AvailableModels.SMOLVLM in to_download,
with_easyocr=_AvailableModels.EASYOCR in to_download,
)
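
Prefetching the SmolVLM weights without the CLI would look roughly like this; it assumes the download_models helper that this command wraps lives in docling.utils.model_downloader, which is not shown in this diff:

# Assumption: the CLI forwards these flags to a download_models() helper;
# the import path below is not visible in this diff.
from docling.utils.model_downloader import download_models

# Fetch only the SmolVLM weights used by the default picture description options.
download_models(
    with_layout=False,
    with_tableformer=False,
    with_code_formula=False,
    with_picture_classifier=False,
    with_smolvlm=True,
    with_easyocr=False,
)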

54 changes: 52 additions & 2 deletions docling/datamodel/pipeline_options.py
@@ -2,9 +2,9 @@
import os
from enum import Enum
from pathlib import Path
from typing import Any, List, Literal, Optional, Union
from typing import Annotated, Any, Dict, List, Literal, Optional, Union

from pydantic import BaseModel, ConfigDict, Field, model_validator
from pydantic import AnyUrl, BaseModel, ConfigDict, Field, model_validator
from pydantic_settings import BaseSettings, SettingsConfigDict

_log = logging.getLogger(__name__)
@@ -184,6 +184,51 @@ class OcrMacOptions(OcrOptions):
)


class PictureDescriptionBaseOptions(BaseModel):
kind: str
batch_size: int = 8
scale: float = 2

bitmap_area_threshold: float = (
0.2  # percentage of the area for a bitmap to be processed with the models
)


class PictureDescriptionApiOptions(PictureDescriptionBaseOptions):
kind: Literal["api"] = "api"

url: AnyUrl = AnyUrl("http://localhost:8000/v1/chat/completions")
headers: Dict[str, str] = {}
params: Dict[str, Any] = {}
timeout: float = 20

prompt: str = "Describe this image in a few sentences."
provenance: str = ""


class PictureDescriptionVlmOptions(PictureDescriptionBaseOptions):
kind: Literal["vlm"] = "vlm"

repo_id: str
prompt: str = "Describe this image in a few sentences."
# Config from here https://huggingface.co/docs/transformers/en/main_classes/text_generation#transformers.GenerationConfig
generation_config: Dict[str, Any] = dict(max_new_tokens=200, do_sample=False)

@property
def repo_cache_folder(self) -> str:
return self.repo_id.replace("/", "--")


smolvlm_picture_description = PictureDescriptionVlmOptions(
repo_id="HuggingFaceTB/SmolVLM-256M-Instruct"
)
# phi_picture_description = PictureDescriptionVlmOptions(repo_id="microsoft/Phi-3-vision-128k-instruct")
granite_picture_description = PictureDescriptionVlmOptions(
repo_id="ibm-granite/granite-vision-3.1-2b-preview",
prompt="What is shown in this image?",
)


# Define an enum for the backend options
class PdfBackend(str, Enum):
"""Enum of valid PDF backends."""
@@ -223,6 +268,7 @@ class PdfPipelineOptions(PipelineOptions):
do_code_enrichment: bool = False # True: perform code OCR
do_formula_enrichment: bool = False # True: perform formula OCR, return Latex code
do_picture_classification: bool = False # True: classify pictures in documents
do_picture_description: bool = False  # True: describe pictures in documents

table_structure_options: TableStructureOptions = TableStructureOptions()
ocr_options: Union[
@@ -232,6 +278,10 @@ class PdfPipelineOptions(PipelineOptions):
OcrMacOptions,
RapidOcrOptions,
] = Field(EasyOcrOptions(), discriminator="kind")
picture_description_options: Annotated[
Union[PictureDescriptionApiOptions, PictureDescriptionVlmOptions],
Field(discriminator="kind"),
] = smolvlm_picture_description

images_scale: float = 1.0
generate_page_images: bool = False
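
With these options in place, picking a picture description backend is a single assignment on the pipeline options; smolvlm_picture_description is the default. A sketch, where the custom repo id, prompt, and generation settings in the last block are illustrative and not part of this change:

from docling.datamodel.pipeline_options import (
    PdfPipelineOptions,
    PictureDescriptionVlmOptions,
    granite_picture_description,
    smolvlm_picture_description,
)

pipeline_options = PdfPipelineOptions()
pipeline_options.do_picture_description = True

# Default preset: SmolVLM-256M-Instruct.
pipeline_options.picture_description_options = smolvlm_picture_description

# Or the Granite vision preset with its own prompt.
pipeline_options.picture_description_options = granite_picture_description

# Or any other vision-to-sequence checkpoint (values below are illustrative).
pipeline_options.picture_description_options = PictureDescriptionVlmOptions(
    repo_id="HuggingFaceTB/SmolVLM-500M-Instruct",
    prompt="Describe this image in three sentences.",
    generation_config={"max_new_tokens": 120, "do_sample": False},
)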
4 changes: 2 additions & 2 deletions docling/models/base_model.py
@@ -1,7 +1,7 @@
from abc import ABC, abstractmethod
from typing import Any, Generic, Iterable, Optional

from docling_core.types.doc import BoundingBox, DoclingDocument, NodeItem, TextItem
from docling_core.types.doc import BoundingBox, DocItem, DoclingDocument, NodeItem
from typing_extensions import TypeVar

from docling.datamodel.base_models import ItemAndImageEnrichmentElement, Page
@@ -64,7 +64,7 @@ def prepare_element(
if not self.is_processable(doc=conv_res.document, element=element):
return None

assert isinstance(element, TextItem)
assert isinstance(element, DocItem)
element_prov = element.prov[0]

bbox = element_prov.bbox
105 changes: 105 additions & 0 deletions docling/models/picture_description_api_model.py
@@ -0,0 +1,105 @@
import base64
import io
import logging
from typing import Iterable, List, Optional

import httpx
from docling_core.types.doc import PictureItem
from docling_core.types.doc.document import ( # TODO: move import to docling_core.types.doc
PictureDescriptionData,
)
from PIL import Image
from pydantic import BaseModel, ConfigDict

from docling.datamodel.pipeline_options import PictureDescriptionApiOptions
from docling.models.picture_description_base_model import PictureDescriptionBaseModel

_log = logging.getLogger(__name__)


class ChatMessage(BaseModel):
role: str
content: str


class ResponseChoice(BaseModel):
index: int
message: ChatMessage
finish_reason: str


class ResponseUsage(BaseModel):
prompt_tokens: int
completion_tokens: int
total_tokens: int


class ApiResponse(BaseModel):
model_config = ConfigDict(
protected_namespaces=(),
)

id: str
model: Optional[str] = None # returned by openai
choices: List[ResponseChoice]
created: int
usage: ResponseUsage


class PictureDescriptionApiModel(PictureDescriptionBaseModel):
# elements_batch_size = 4

def __init__(self, enabled: bool, options: PictureDescriptionApiOptions):
super().__init__(enabled=enabled, options=options)
self.options: PictureDescriptionApiOptions

if self.enabled:
if options.url.host != "localhost":
raise NotImplementedError(
"The options try to connect to remote APIs which are not yet allowed."
)

def _annotate_images(self, images: Iterable[Image.Image]) -> Iterable[str]:
# Note: technically we could make a batch request here,
# but not all APIs will allow for it. For example, vllm won't allow more than 1.
for image in images:
img_io = io.BytesIO()
image.save(img_io, "PNG")
image_base64 = base64.b64encode(img_io.getvalue()).decode("utf-8")

messages = [
{
"role": "user",
"content": [
{
"type": "text",
"text": self.options.prompt,
},
{
"type": "image_url",
"image_url": {
"url": f"data:image/png;base64,{image_base64}"
},
},
],
}
]

payload = {
"messages": messages,
**self.options.params,
}

r = httpx.post(
str(self.options.url),
headers=self.options.headers,
json=payload,
timeout=self.options.timeout,
)
if not r.is_success:
_log.error(f"Error calling the API. Response was {r.text}")
r.raise_for_status()

api_resp = ApiResponse.model_validate_json(r.text)
generated_text = api_resp.choices[0].message.content.strip()
yield generated_text
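
A sketch of configuring the API backend against a local OpenAI-compatible endpoint, for example one served by vLLM; the model name in params is an assumption about whatever that server exposes, and note that __init__ above rejects any host other than localhost for now:

from docling.datamodel.pipeline_options import (
    PdfPipelineOptions,
    PictureDescriptionApiOptions,
)

pipeline_options = PdfPipelineOptions()
pipeline_options.do_picture_description = True
pipeline_options.picture_description_options = PictureDescriptionApiOptions(
    url="http://localhost:8000/v1/chat/completions",
    # Extra fields in params are merged into the request payload; the model
    # name here is an assumption about the locally served checkpoint.
    params={"model": "HuggingFaceTB/SmolVLM-256M-Instruct"},
    prompt="Describe this image in a few sentences.",
    timeout=60,
)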
64 changes: 64 additions & 0 deletions docling/models/picture_description_base_model.py
@@ -0,0 +1,64 @@
import logging
from pathlib import Path
from typing import Any, Iterable, List, Optional, Union

from docling_core.types.doc import (
DoclingDocument,
NodeItem,
PictureClassificationClass,
PictureItem,
)
from docling_core.types.doc.document import ( # TODO: move import to docling_core.types.doc
PictureDescriptionData,
)
from PIL import Image

from docling.datamodel.pipeline_options import PictureDescriptionBaseOptions
from docling.models.base_model import (
BaseItemAndImageEnrichmentModel,
ItemAndImageEnrichmentElement,
)


class PictureDescriptionBaseModel(BaseItemAndImageEnrichmentModel):
images_scale: float = 2.0

def __init__(
self,
enabled: bool,
options: PictureDescriptionBaseOptions,
):
self.enabled = enabled
self.options = options
self.provenance = "not-implemented"

def is_processable(self, doc: DoclingDocument, element: NodeItem) -> bool:
return self.enabled and isinstance(element, PictureItem)

def _annotate_images(self, images: Iterable[Image.Image]) -> Iterable[str]:
raise NotImplementedError

def __call__(
self,
doc: DoclingDocument,
element_batch: Iterable[ItemAndImageEnrichmentElement],
) -> Iterable[NodeItem]:
if not self.enabled:
for element in element_batch:
yield element.item
return

images: List[Image.Image] = []
elements: List[PictureItem] = []
for el in element_batch:
assert isinstance(el.item, PictureItem)
elements.append(el.item)
images.append(el.image)

outputs = self._annotate_images(images)

for item, output in zip(elements, outputs):
item.annotations.append(
PictureDescriptionData(text=output, provenance=self.provenance)
)
yield item
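
Because the base class only requires _annotate_images, a custom description backend can be plugged in with a few lines. A toy, purely illustrative subclass:

from typing import Iterable

from PIL import Image

from docling.datamodel.pipeline_options import PictureDescriptionBaseOptions
from docling.models.picture_description_base_model import PictureDescriptionBaseModel


class ConstantPictureDescription(PictureDescriptionBaseModel):
    """Toy backend that attaches the same caption to every picture."""

    def __init__(self, enabled: bool, options: PictureDescriptionBaseOptions):
        super().__init__(enabled=enabled, options=options)
        self.provenance = "constant-caption-demo"

    def _annotate_images(self, images: Iterable[Image.Image]) -> Iterable[str]:
        for _ in images:
            yield "A picture from the document."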