[VLM] Merged multi-modal processor for Molmo (vllm-project#12966)
Signed-off-by: Linkun Chen <[email protected]>
DarkLight1337 authored and lk-chen committed Mar 5, 2025
1 parent 4034cfb commit 51cea54
Showing 9 changed files with 750 additions and 498 deletions.
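Only six of the nine changed files are expanded below; the model-side rewrite (presumably the Molmo model file itself, among others) was not loaded in this view. For orientation, "merged multi-modal processor" refers to vLLM's unified processing interface, where the model registers a `BaseMultiModalProcessor` subclass that owns both tokenization and image preprocessing, rather than separate input mappers. The following is a purely hypothetical registration sketch of that pattern; the Molmo-specific class names and hook set are assumptions, not this commit's actual code:

from torch import nn

from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                        BaseProcessingInfo)
from vllm.multimodal.profiling import BaseDummyInputsBuilder


class MolmoProcessingInfo(BaseProcessingInfo):
    ...  # hypothetical: wraps the HF processor/config lookup


class MolmoDummyInputsBuilder(BaseDummyInputsBuilder[MolmoProcessingInfo]):
    ...  # hypothetical: builds dummy inputs for memory profiling


class MolmoMultiModalProcessor(BaseMultiModalProcessor[MolmoProcessingInfo]):
    ...  # hypothetical: maps images/image_input_idx/image_masks to vLLM fields


@MULTIMODAL_REGISTRY.register_processor(MolmoMultiModalProcessor,
                                        info=MolmoProcessingInfo,
                                        dummy_inputs=MolmoDummyInputsBuilder)
class MolmoForCausalLM(nn.Module):
    ...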
2 changes: 1 addition & 1 deletion docs/source/models/supported_models.md
@@ -793,7 +793,7 @@ See [this page](#generative-models) for more information on how to use generativ
 - * `MolmoForCausalLM`
   * Molmo
   * T + I
-  * `allenai/Molmo-7B-D-0924`, `allenai/Molmo-72B-0924`, etc.
+  * `allenai/Molmo-7B-D-0924`, `allenai/Molmo-7B-O-0924`, etc.
   * ✅︎
   * ✅︎
   * ✅︎
2 changes: 1 addition & 1 deletion tests/models/decoder_only/language/test_models.py
@@ -27,7 +27,7 @@
             marks=[pytest.mark.core_model, pytest.mark.cpu_model],
         ),
         pytest.param(
-            "THUDM/chatglm3-6b",  # ChatGLM (text-only)
+            "THUDM/chatglm3-6b",  # chatglm (text-only)
         ),
         pytest.param(
             "meta-llama/Llama-3.2-1B-Instruct",  # llama
5 changes: 2 additions & 3 deletions tests/models/decoder_only/vision_language/test_models.py
@@ -404,11 +404,10 @@
     "molmo": VLMTestInfo(
         models=["allenai/Molmo-7B-D-0924"],
         test_type=(VLMTestType.IMAGE),
-        prompt_formatter=lambda img_prompt:"User: " + img_prompt + " Assistant:",  # noqa: E501
+        prompt_formatter=identity,
         max_model_len=4096,
         max_num_seqs=2,
-        image_size_factors=[(),(1.0, 1.0, 1.0)],
-        patch_hf_runner=model_utils.mlomo_patch_hf_runner,
+        patch_hf_runner=model_utils.molmo_patch_hf_runner,
         postprocess_inputs=model_utils.molmo_post_processor,
     ),
     # Tests for phi3v currently live in another file because of a bug in
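A note on the hunk above: prompt_formatter drops the hand-rolled "User: … Assistant:" wrapping in favor of identity, presumably because with the merged processor the raw prompt now goes straight to Molmo's own HF processor, which applies that chat formatting itself. identity is most likely the trivial pass-through helper (vllm.utils ships an equivalent):

def identity(value):
    # Return the prompt unchanged; formatting is left to the HF processor.
    return value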
98 changes: 21 additions & 77 deletions tests/models/decoder_only/vision_language/vlm_utils/model_utils.py
@@ -6,7 +6,7 @@
 import re
 import types
 from pathlib import PosixPath
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Callable, List, Optional, Tuple, Union
 
 import torch
 from PIL.Image import Image
@@ -17,9 +17,7 @@
 from vllm.transformers_utils.tokenizer import patch_padding_side
 from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
 
-from .....conftest import (HfRunner, ImageAsset, PromptAudioInput,
-                           PromptImageInput, PromptVideoInput, _ImageAssets)
-from ....utils import TokensTextLogprobs
+from .....conftest import HfRunner, ImageAsset, _ImageAssets
 from .types import RunnerOutput
 
 
@@ -522,74 +520,7 @@ def _generate(self, *args, **kwargs):
     return hf_model
 
-
-def _generate_greedy_logprobs_limit(
-    self,
-    prompts: List[str],
-    max_tokens: int,
-    num_logprobs: int,
-    images: Optional[PromptImageInput] = None,
-    audios: Optional[PromptAudioInput] = None,
-    videos: Optional[PromptVideoInput] = None,
-    **kwargs: Any,
-) -> List[TokensTextLogprobs]:
-    all_inputs = self.get_inputs(prompts,
-                                 images=images,
-                                 videos=videos,
-                                 audios=audios)
-
-    # Process in batches for inference.
-    if len(all_inputs):
-        input_ids_lst = []
-        images_lst = []
-        images_input_idx_lst = []
-        imges_masks_lst = []
-        for inputs in all_inputs:
-            input_ids_lst.append(inputs["input_ids"])
-            images_lst.append(inputs["images"])
-            images_input_idx_lst.append(inputs["image_input_idx"])
-            imges_masks_lst.append(inputs["image_masks"])
-        batch_inputs = {}
-        batch_inputs['input_ids'] = torch.cat(input_ids_lst, dim=0)
-        batch_inputs['images'] = torch.cat(images_lst, dim=0)
-        batch_inputs['image_input_idx'] = torch.cat(images_input_idx_lst,
-                                                    dim=0)
-        batch_inputs['image_masks'] = torch.cat(imges_masks_lst, dim=0)
-
-        outputs = self.model.generate_from_batch(
-            batch=self.wrap_device(batch_inputs,
-                                   device=self.model.device.type),
-            generation_config=GenerationConfig(
-                max_new_tokens=max_tokens,
-                stop_strings="<|endoftext|>",
-                do_sample=False,
-            ),
-            tokenizer=self.tokenizer,
-            output_hidden_states=True,
-            return_dict_in_generate=True,
-        )
-
-        all_logprobs: List[List[Dict[int, float]]] = []
-        all_output_ids: List[List[int]] = []
-        all_output_strs: List[str] = []
-
-        for index in range(len(all_inputs)):
-            (
-                seq_logprobs_lst,
-                output_len,
-            ) = self._hidden_states_to_logprobs(outputs.hidden_states,
-                                                num_logprobs)
-            all_logprobs.append(seq_logprobs_lst)
-            seq_ids = outputs.sequences[index]
-            output_ids = seq_ids[-output_len:]
-            all_output_ids.append(output_ids.tolist())
-            all_output_strs.append(self.tokenizer.decode(output_ids))
-        outputs = zip(all_output_ids, all_output_strs, all_logprobs)
-        return [(output_ids, output_str, output_logprobs)
-                for output_ids, output_str, output_logprobs in outputs]
-
-
 ####### Molmo-specific HuggingFace runner patchers
-def mlomo_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
+def molmo_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
     """Patches and returns an instance of the HfRunner to use for Molmo."""
     hf_processor = hf_model.processor
 
@@ -598,10 +529,23 @@ def _processor(*args, **kwargs):
 
     hf_model.processor = _processor
 
-    setattr(  # noqa: B010
-        hf_model,
-        "generate_greedy_logprobs_limit",
-        types.MethodType(_generate_greedy_logprobs_limit, hf_model),
-    )
+    def _generate(self, max_new_tokens=None, do_sample=None, **kwargs):
+        batch = {
+            k: kwargs.pop(k)
+            for k in ("input_ids", "images", "image_input_idx", "image_masks")
+            if k in kwargs
+        }
+
+        return self.generate_from_batch(
+            batch,
+            generation_config=GenerationConfig(
+                max_new_tokens=max_new_tokens,
+                stop_strings="<|endoftext|>",
+                do_sample=do_sample,
+            ),
+            **kwargs,
+        )
+
+    hf_model.model.generate = types.MethodType(_generate, hf_model.model)
 
     return hf_model
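The net effect of the two hunks above: instead of overriding the runner's entire generate_greedy_logprobs_limit (the roughly 65-line copy deleted above), the patch now only wraps model.generate, popping the four Molmo-specific tensors out of kwargs and rerouting the call to Molmo's generate_from_batch, while the shared HfRunner logic keeps handling batching, decoding, and logprob extraction. A minimal usage sketch under assumed harness wiring; only the tensor keys and the stop string are taken from the code above:

def run_molmo_once(hf_model, image, prompt: str):
    # `hf_model` is an HfRunner already passed through molmo_patch_hf_runner().
    inputs = hf_model.processor(prompt, images=[image])  # assumed call shape
    return hf_model.model.generate(
        max_new_tokens=128,
        do_sample=False,
        # Molmo-specific tensors; the patched _generate() pops these
        # into the batch dict passed to generate_from_batch():
        input_ids=inputs["input_ids"],
        images=inputs["images"],
        image_input_idx=inputs["image_input_idx"],
        image_masks=inputs["image_masks"],
        # Forwarded via **kwargs; HF generate() needs a tokenizer
        # when stopping on stop_strings:
        tokenizer=hf_model.tokenizer,
    )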
2 changes: 2 additions & 0 deletions tests/models/multimodal/processing/test_common.py
@@ -168,6 +168,8 @@ def _test_processing_correctness(
         "mistral-community/pixtral-12b",
         "openbmb/MiniCPM-o-2_6",
         "openbmb/MiniCPM-V-2_6",
+        "allenai/Molmo-7B-D-0924",
+        "allenai/Molmo-7B-O-0924",
         "nvidia/NVLM-D-72B",
         "Qwen/Qwen-VL-Chat",
         "Qwen/Qwen2-VL-2B-Instruct",
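The two entries above enroll both Molmo checkpoints in _test_processing_correctness, the shared test that exercises a model's merged multi-modal processor and checks its outputs for consistency. One way to run just these cases locally, assuming the model id appears in the generated test ids (an assumption about the parametrization):

import pytest

# Filter the shared processing-correctness test down to the Molmo cases.
pytest.main([
    "tests/models/multimodal/processing/test_common.py",
    "-k", "Molmo",
])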
1 change: 1 addition & 0 deletions tests/models/registry.py
@@ -256,6 +256,7 @@ def check_available_online(
     "MiniCPMV": _HfExamplesInfo("openbmb/MiniCPM-V-2_6",
                                 trust_remote_code=True),
     "MolmoForCausalLM": _HfExamplesInfo("allenai/Molmo-7B-D-0924",
+                                        extras={"olmo": "allenai/Molmo-7B-O-0924"},  # noqa: E501
                                         trust_remote_code=True),
     "NVLM_D": _HfExamplesInfo("nvidia/NVLM-D-72B",
                               trust_remote_code=True),
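For reference, the registry entry above pairs a default example checkpoint with named alternates, so the O-variant stays testable without becoming the default. Judging from the constructor call (the attribute names are an assumption about _HfExamplesInfo, apparently a simple record type in this file):

info = _HfExamplesInfo("allenai/Molmo-7B-D-0924",
                       extras={"olmo": "allenai/Molmo-7B-O-0924"},
                       trust_remote_code=True)

# Assumed attribute names, mirroring the constructor arguments:
assert info.default == "allenai/Molmo-7B-D-0924"
assert info.extras["olmo"] == "allenai/Molmo-7B-O-0924"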