diff --git a/docs/source/models/supported_models.md b/docs/source/models/supported_models.md index 55b3f52356cd0..86b7461782283 100644 --- a/docs/source/models/supported_models.md +++ b/docs/source/models/supported_models.md @@ -793,7 +793,7 @@ See [this page](#generative-models) for more information on how to use generativ - * `MolmoForCausalLM` * Molmo * T + I - * `allenai/Molmo-7B-D-0924`, `allenai/Molmo-72B-0924`, etc. + * `allenai/Molmo-7B-D-0924`, `allenai/Molmo-7B-O-0924`, etc. * ✅︎ * ✅︎ * ✅︎ diff --git a/tests/models/decoder_only/language/test_models.py b/tests/models/decoder_only/language/test_models.py index c6d5244318a32..71e4a9f11ab82 100644 --- a/tests/models/decoder_only/language/test_models.py +++ b/tests/models/decoder_only/language/test_models.py @@ -27,7 +27,7 @@ marks=[pytest.mark.core_model, pytest.mark.cpu_model], ), pytest.param( - "THUDM/chatglm3-6b", # ChatGLM (text-only) + "THUDM/chatglm3-6b", # chatglm (text-only) ), pytest.param( "meta-llama/Llama-3.2-1B-Instruct", # llama diff --git a/tests/models/decoder_only/vision_language/test_models.py b/tests/models/decoder_only/vision_language/test_models.py index b00ec6fa69995..4ed61cfc9b7c0 100644 --- a/tests/models/decoder_only/vision_language/test_models.py +++ b/tests/models/decoder_only/vision_language/test_models.py @@ -404,11 +404,10 @@ "molmo": VLMTestInfo( models=["allenai/Molmo-7B-D-0924"], test_type=(VLMTestType.IMAGE), - prompt_formatter=lambda img_prompt:"User: " + img_prompt + " Assistant:", # noqa: E501 + prompt_formatter=identity, max_model_len=4096, max_num_seqs=2, - image_size_factors=[(),(1.0, 1.0, 1.0)], - patch_hf_runner=model_utils.mlomo_patch_hf_runner, + patch_hf_runner=model_utils.molmo_patch_hf_runner, postprocess_inputs=model_utils.molmo_post_processor, ), # Tests for phi3v currently live in another file because of a bug in diff --git a/tests/models/decoder_only/vision_language/vlm_utils/model_utils.py b/tests/models/decoder_only/vision_language/vlm_utils/model_utils.py index ced891e1e2c20..408ce9cfeadab 100644 --- a/tests/models/decoder_only/vision_language/vlm_utils/model_utils.py +++ b/tests/models/decoder_only/vision_language/vlm_utils/model_utils.py @@ -6,7 +6,7 @@ import re import types from pathlib import PosixPath -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Callable, List, Optional, Tuple, Union import torch from PIL.Image import Image @@ -17,9 +17,7 @@ from vllm.transformers_utils.tokenizer import patch_padding_side from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE -from .....conftest import (HfRunner, ImageAsset, PromptAudioInput, - PromptImageInput, PromptVideoInput, _ImageAssets) -from ....utils import TokensTextLogprobs +from .....conftest import HfRunner, ImageAsset, _ImageAssets from .types import RunnerOutput @@ -522,74 +520,7 @@ def _generate(self, *args, **kwargs): return hf_model -def _generate_greedy_logprobs_limit( - self, - prompts: List[str], - max_tokens: int, - num_logprobs: int, - images: Optional[PromptImageInput] = None, - audios: Optional[PromptAudioInput] = None, - videos: Optional[PromptVideoInput] = None, - **kwargs: Any, -) -> List[TokensTextLogprobs]: - all_inputs = self.get_inputs(prompts, - images=images, - videos=videos, - audios=audios) - - # Process in batches for inference. 
- if len(all_inputs): - input_ids_lst = [] - images_lst = [] - images_input_idx_lst = [] - imges_masks_lst = [] - for inputs in all_inputs: - input_ids_lst.append(inputs["input_ids"]) - images_lst.append(inputs["images"]) - images_input_idx_lst.append(inputs["image_input_idx"]) - imges_masks_lst.append(inputs["image_masks"]) - batch_inputs = {} - batch_inputs['input_ids'] = torch.cat(input_ids_lst, dim=0) - batch_inputs['images'] = torch.cat(images_lst, dim=0) - batch_inputs['image_input_idx'] = torch.cat(images_input_idx_lst, - dim=0) - batch_inputs['image_masks'] = torch.cat(imges_masks_lst, dim=0) - - outputs = self.model.generate_from_batch( - batch=self.wrap_device(batch_inputs, - device=self.model.device.type), - generation_config=GenerationConfig( - max_new_tokens=max_tokens, - stop_strings="<|endoftext|>", - do_sample=False, - ), - tokenizer=self.tokenizer, - output_hidden_states=True, - return_dict_in_generate=True, - ) - - all_logprobs: List[List[Dict[int, float]]] = [] - all_output_ids: List[List[int]] = [] - all_output_strs: List[str] = [] - - for index in range(len(all_inputs)): - ( - seq_logprobs_lst, - output_len, - ) = self._hidden_states_to_logprobs(outputs.hidden_states, - num_logprobs) - all_logprobs.append(seq_logprobs_lst) - seq_ids = outputs.sequences[index] - output_ids = seq_ids[-output_len:] - all_output_ids.append(output_ids.tolist()) - all_output_strs.append(self.tokenizer.decode(output_ids)) - outputs = zip(all_output_ids, all_output_strs, all_logprobs) - return [(output_ids, output_str, output_logprobs) - for output_ids, output_str, output_logprobs in outputs] - - -####### Molmo-specific HuggingFace runner patchers -def mlomo_patch_hf_runner(hf_model: HfRunner) -> HfRunner: +def molmo_patch_hf_runner(hf_model: HfRunner) -> HfRunner: """Patches and returns an instance of the HfRunner to use for Molmo.""" hf_processor = hf_model.processor @@ -598,10 +529,23 @@ def _processor(*args, **kwargs): hf_model.processor = _processor - setattr( # noqa: B010 - hf_model, - "generate_greedy_logprobs_limit", - types.MethodType(_generate_greedy_logprobs_limit, hf_model), - ) + def _generate(self, max_new_tokens=None, do_sample=None, **kwargs): + batch = { + k: kwargs.pop(k) + for k in ("input_ids", "images", "image_input_idx", "image_masks") + if k in kwargs + } + + return self.generate_from_batch( + batch, + generation_config=GenerationConfig( + max_new_tokens=max_new_tokens, + stop_strings="<|endoftext|>", + do_sample=do_sample, + ), + **kwargs, + ) + + hf_model.model.generate = types.MethodType(_generate, hf_model.model) return hf_model diff --git a/tests/models/multimodal/processing/test_common.py b/tests/models/multimodal/processing/test_common.py index 67ef8b17ab8c1..88dcc32f44f52 100644 --- a/tests/models/multimodal/processing/test_common.py +++ b/tests/models/multimodal/processing/test_common.py @@ -168,6 +168,8 @@ def _test_processing_correctness( "mistral-community/pixtral-12b", "openbmb/MiniCPM-o-2_6", "openbmb/MiniCPM-V-2_6", + "allenai/Molmo-7B-D-0924", + "allenai/Molmo-7B-O-0924", "nvidia/NVLM-D-72B", "Qwen/Qwen-VL-Chat", "Qwen/Qwen2-VL-2B-Instruct", diff --git a/tests/models/registry.py b/tests/models/registry.py index 7b1db55494fe4..66a487ca60e90 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -256,6 +256,7 @@ def check_available_online( "MiniCPMV": _HfExamplesInfo("openbmb/MiniCPM-V-2_6", trust_remote_code=True), "MolmoForCausalLM": _HfExamplesInfo("allenai/Molmo-7B-D-0924", + extras={"olmo": "allenai/Molmo-7B-O-0924"}, # noqa: E501 
                                         trust_remote_code=True),
     "NVLM_D": _HfExamplesInfo("nvidia/NVLM-D-72B",
                               trust_remote_code=True),
diff --git a/vllm/model_executor/models/molmo.py b/vllm/model_executor/models/molmo.py
index b524a14977b16..feb5850223178 100644
--- a/vllm/model_executor/models/molmo.py
+++ b/vllm/model_executor/models/molmo.py
@@ -1,18 +1,20 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import math
-import re
-from array import array
 from dataclasses import dataclass
-from functools import lru_cache, partial
-from typing import Iterable, List, Mapping, Optional, Set, Tuple, TypedDict
+from functools import cached_property, partial
+from typing import (Iterable, List, Mapping, Optional, Set, Tuple, TypedDict,
+                    Union, cast)
 
+import numpy as np
 import torch
+import torch.nn as nn
+import torch.nn.functional as F
 from einops import rearrange
-from PIL import Image
-from torch import nn
-from torch.nn import functional as F
-from transformers import PretrainedConfig
+from transformers import (BatchFeature, PretrainedConfig, ProcessorMixin,
+                          TensorType)
+from transformers.image_utils import ImageInput
+from transformers.tokenization_utils_base import TextInput
 
 from vllm.attention import Attention, AttentionMetadata
 from vllm.attention.layer import MultiHeadAttention
@@ -22,8 +24,6 @@
                               get_tensor_model_parallel_world_size,
                               split_tensor_along_last_dim,
                               tensor_model_parallel_all_gather)
-from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
-                         InputContext, token_inputs)
 from vllm.model_executor import SamplingMetadata
 from vllm.model_executor.layers.activation import (MulAndSilu, QuickGELU,
                                                    SiluAndMul)
@@ -40,15 +40,21 @@
     ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.module_mapping import MultiModelKeys
-from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
-from vllm.multimodal.inputs import NestedTensors, PlaceholderRange
-from vllm.multimodal.utils import cached_get_tokenizer
-from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
-                           SequenceData)
-from vllm.transformers_utils.processor import get_processor
+from vllm.multimodal import MULTIMODAL_REGISTRY
+from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
+                                    NestedTensors)
+from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
+                                   MultiModalDataItems)
+from vllm.multimodal.processing import (BaseMultiModalProcessor,
+                                        BaseProcessingInfo, PromptReplacement,
+                                        PromptReplacementDetails)
+from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
+from vllm.sequence import IntermediateTensors
+from vllm.utils import JSONTree, json_map_leaves
 
 from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP
-from .utils import (AutoWeightsLoader, WeightsMapper, is_pp_missing_parameter,
+from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
+                    is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix, merge_multimodal_embeddings)
 
@@ -56,38 +62,39 @@
 VIT_LAYERS = [-2, -9]
 NUM_PREFIX_TOKENS = 1
 ADDITIONAL_VOCAB_SIZE = 128
-DEFAULT_IMAGE_PATCH_TOKEN_ID = 152066
-DEFAULT_IM_START_TOKEN_ID = 152067
-DEFAULT_IM_END_TOKEN_ID = 152064
-DEFAULT_IM_COL_TOKEN_ID = 152065
+IMAGE_PATCH_TOKEN = "<im_patch>"
+IM_COL_TOKEN = "<im_col>"
+IM_START_TOKEN = "<im_start>"
+IM_END_TOKEN = "<im_end>"
+POOLING_SIZE = 2
 
 
 class MolmoImageInputs(TypedDict):
-    images: torch.Tensor
-    """Shape:
-    `(batch_size, num_crops, num_patch, patch_dim)`
-    """
+    images: Union[torch.Tensor,
List[torch.Tensor]] + """Shape: `(batch_size, num_crops, num_patch, patch_dim)`""" - image_input_idx: torch.Tensor - """Shape: - `(batch_size, num_crops, num_patch)` - """ + image_masks: Optional[Union[torch.Tensor, List[torch.Tensor]]] + """Shape: `(batch_size, num_crops, num_patch)`""" - seq_len: torch.Tensor - """Shape: - `(batch_size, )` + feat_is_patch: Union[torch.Tensor, List[torch.Tensor]] """ + A boolean mask indicating which image features correspond + to patch tokens. - image_masks: Optional[torch.Tensor] - """Shape: - `(batch_size, num_crops, num_patch)` + Shape: `(batch_size, num_crops, num_patch)` """ - image_start_end: Tuple[int, int] - """Starting and ending index of placeholder - tokens + embed_is_patch: Union[torch.Tensor, List[torch.Tensor]] + """ + A boolean mask indicating which image embeddings correspond + to patch tokens. + + Shape: `(batch_size, num_embeds)` """ + num_crops: torch.Tensor + """Shape: `(batch_size, num_images)`""" + @dataclass class VisionBackboneConfig: @@ -335,7 +342,7 @@ def add_pos_emb(self, x: torch.Tensor, patch_num: int) -> torch.Tensor: def forward(self, x: torch.Tensor, - patch_num: int = None) -> List[torch.Tensor]: + patch_num: Optional[int] = None) -> List[torch.Tensor]: """ : param x: (batch_size, num_patch, n_pixels) """ @@ -465,7 +472,7 @@ def forward( return output -class LanuageModelMLP(nn.Module): +class LanguageModelMLP(nn.Module): """Molmo's LLM mlp.""" def __init__(self, @@ -559,7 +566,7 @@ def __init__( prefix=f"{prefix}.self_attn") # MLP block. - self.mlp = LanuageModelMLP(config, quant_config=quant_config) + self.mlp = LanguageModelMLP(config, quant_config=quant_config) # LayerNorm assert config.layer_norm_type == "rms" @@ -638,8 +645,8 @@ def __init__( self.vit_layers = VIT_LAYERS self.image_num_patch = vision_config.image_num_patch self.llm_patches_per_crop = ( - (self.image_num_patch[0] + 1) // 2, - (self.image_num_patch[1] + 1) // 2, + (self.image_num_patch[0] + 1) // POOLING_SIZE, + (self.image_num_patch[1] + 1) // POOLING_SIZE, ) self.image_vit = VisionTransformer(vision_config, quant_config=quant_config) @@ -723,19 +730,19 @@ def forward( image_features = image_features.reshape( (batch_size, num_image) + self.image_num_patch + (-1, ), ) - if self.image_num_patch[0] % 2 == 1: - # Pad so we can still pool 2x2 patches + if (missing_w := self.image_num_patch[0] % POOLING_SIZE): + # Padding for image pooling (see below) image_features = F.pad( image_features, - (0, 0, 0, 1, 0, 1, 0, 0, 0, 0), + (0, 0, 0, missing_w, 0, missing_w, 0, 0, 0, 0), ) # image pooling image_features = rearrange( image_features, 'b n (h dh) (w dw) c -> (b n h w) (dh dw) c', - dh=2, - dw=2, + dh=POOLING_SIZE, + dw=POOLING_SIZE, ) query = image_features.mean(-2, keepdim=True) @@ -888,249 +895,513 @@ def load_weights(self, weights: Iterable[Tuple[str, return loaded_params -cached_get_processor = lru_cache(get_processor) +def _lowest_multiple(x: int, k: int) -> int: + return (x // k) * k + +def get_num_patches( + num_tiles: int, + *, + crop_patches: int, + left_margin: int, + right_margin: int, + pooling_size: int, +) -> int: + if num_tiles == 1: + return _lowest_multiple(crop_patches + pooling_size - 1, pooling_size) -def get_num_patches(num_tiles: int, crop_patches: int, left_margin: int, - right_margin: int, pooling_size: int) -> int: crop_window_patches = crop_patches - (left_margin + right_margin) - if num_tiles > 1: - left_crop_window_patches = (crop_window_patches + left_margin + - pooling_size - - 1) // pooling_size * pooling_size - 
middle_crop_window_patches = (crop_window_patches + pooling_size - - 1) // pooling_size * pooling_size - right_crop_window_patches = (crop_window_patches + right_margin + - pooling_size - - 1) // pooling_size * pooling_size - return left_crop_window_patches + ( - num_tiles - - 2) * middle_crop_window_patches + right_crop_window_patches - else: - single_crop_window_patches = (crop_patches + pooling_size - - 1) // pooling_size * pooling_size - return single_crop_window_patches - - -def get_tokens(tiling_h: int, tiling_w: int, crop_patches: int, - left_margin: int, right_margin: int, pooling_size: int) -> int: - h = get_num_patches(tiling_h, crop_patches, left_margin, right_margin, - pooling_size) - w = get_num_patches(tiling_w, crop_patches, left_margin, right_margin, - pooling_size) - per_row = w // pooling_size + 1 - joint = per_row * (h // pooling_size) + 2 - image_token_length = (crop_patches + pooling_size - 1) // pooling_size - resize = (image_token_length + 1) * image_token_length + 2 - return resize + joint - - -def get_max_tokens(max_crops: int, crop_patches: int, left_margin: int, - right_margin: int, pooling_size: int) -> int: - tilings = [] - for i in range(1, max_crops + 1): - for j in range(1, max_crops + 1): - if i * j <= max_crops: - tilings.append((i, j)) - tokens = [ - get_tokens(tilings[i][0], tilings[i][1], crop_patches, left_margin, - right_margin, pooling_size) for i in range(len(tilings)) - ] - return max(tokens) - - -def get_max_molmo_image_tokens(ctx: InputContext) -> int: - processor = cached_get_processor( - ctx.model_config.model, - trust_remote_code=ctx.model_config.trust_remote_code, - revision=ctx.model_config.code_revision) - image_processor = processor.image_processor - max_llm_image_tokens = get_max_tokens( - image_processor.max_crops, - image_processor.base_image_input_size[0] // - image_processor.image_patch_size, - image_processor.overlap_margins[0], - image_processor.overlap_margins[1], - 2, + + left_num = _lowest_multiple( + crop_window_patches + left_margin + pooling_size - 1, + pooling_size, + ) + middle_num = _lowest_multiple( + crop_window_patches + pooling_size - 1, + pooling_size, + ) + right_num = _lowest_multiple( + crop_window_patches + right_margin + pooling_size - 1, + pooling_size, ) - return max_llm_image_tokens + return left_num + (num_tiles - 2) * middle_num + right_num + + +def get_patches_grid_size( + *, + tiling_h: int, + tiling_w: int, + crop_patches: int, + left_margin: int, + right_margin: int, + pooling_size: int, +) -> tuple[int, int]: + nrows = get_num_patches( + tiling_h, + crop_patches=crop_patches, + left_margin=left_margin, + right_margin=right_margin, + pooling_size=pooling_size, + ) + ncols = get_num_patches( + tiling_w, + crop_patches=crop_patches, + left_margin=left_margin, + right_margin=right_margin, + pooling_size=pooling_size, + ) -# NOTE: preprocessing for the image data has been included in the -# 'input_processor_for_molmo' function -def image_input_mapper_for_molmo( - ctx: InputContext, - data: object, -): - if isinstance(data, list): - assert len(data) == 1, "Molmo supports only one image per prompt." 
- data = data[0] - - return MultiModalKwargs(data) - - -def dummy_data_for_molmo(ctx: InputContext, seq_len: int, - mm_counts: Mapping[str, int]): - processor = cached_get_processor( - ctx.model_config.model, - trust_remote_code=ctx.model_config.trust_remote_code, - revision=ctx.model_config.code_revision) - image_processor = processor.image_processor - - base_image_input_d = image_processor.image_patch_size - left_margin, right_margin = image_processor.overlap_margins - max_crops = image_processor.max_crops - - # Assume: prompt_token_ids always starts with bos_token_id followed image tokens # noqa: E501 - max_llm_image_tokens = get_max_molmo_image_tokens(ctx) - if seq_len - max_llm_image_tokens - 1 < 0: - raise RuntimeError( - f"Molmo cannot process {max_crops} crops in a prompt, " - "please increase max_model_len or reduce number of crops") - - # The vertical image has the maximum number of image tokens due to column tokens. # noqa: E501 - tiling = (max_crops, 1) - total_margin_pixels = base_image_input_d * (right_margin + left_margin) - crop_patches = image_processor.base_image_input_size[ - 0] // base_image_input_d - crop_window_patches = crop_patches - (right_margin + left_margin) - crop_window_size = crop_window_patches * base_image_input_d - - h = crop_window_size * tiling[0] + total_margin_pixels - w = crop_window_size * tiling[1] + total_margin_pixels - - dummy_image = Image.new("RGB", (w, h), color="red") - - out = processor.process("dummy prompt", dummy_image) - - token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, - out["input_ids"][:1 + max_llm_image_tokens]) - token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE, - [0]) * (seq_len - max_llm_image_tokens - 1) - dummy_seqdata = SequenceData(token_ids) - dummy_imgdata = { - "images": out["images"], - "image_input_idx": out["image_input_idx"], - } - if "image_masks" in out: - dummy_imgdata["image_masks"] = out["image_masks"] - dummy_imgdata["seq_len"] = torch.tensor(seq_len, dtype=torch.long) - size = 0 - offset = -1 - for i in range(len(token_ids)): - if token_ids[i] in (DEFAULT_IMAGE_PATCH_TOKEN_ID, - DEFAULT_IM_START_TOKEN_ID, DEFAULT_IM_END_TOKEN_ID, - DEFAULT_IM_COL_TOKEN_ID): - if offset < 0: - offset = i - size += 1 - dummy_imgdata["image_start_end"] = (offset, offset + size) - return DummyData(seq_data=dummy_seqdata, - multi_modal_data={"image": dummy_imgdata}, - multi_modal_placeholders={ - "image": - [PlaceholderRange(offset=offset, length=size)] - }) - - -def pad_images( - max_total_crops: int, - images: torch.Tensor, - image_input_idx: torch.Tensor, - image_masks: Optional[torch.Tensor] = None, + return nrows, ncols + + +def get_candidate_tilings(max_num: int) -> list[tuple[int, int]]: + tilings = [(i, j) for i in range(1, max_num + 1) + for j in range(1, max_num + 1) if i * j <= max_num] + return sorted(tilings, key=lambda x: x[0] * x[1]) + + +def select_tiling( + *, + height: int, + width: int, + patch_size: int, + max_num_patches: int, ): - n = max_total_crops - images.shape[0] - images = F.pad(images, (0, 0, 0, 0, 0, n), value=-1) - image_input_idx = F.pad(image_input_idx, (0, 0, 0, n), value=-1) - if image_masks is not None: - image_masks = F.pad(image_masks, (0, 0, 0, n), value=-1) - return images, image_input_idx, image_masks - - -def input_processor_for_molmo(ctx: InputContext, inputs: DecoderOnlyInputs): - prompt = inputs.get("prompt") - multi_modal_data = inputs.get("multi_modal_data") - image = None if multi_modal_data is None else multi_modal_data.get("image") - - model_config = ctx.model_config - processor = 
cached_get_processor( - ctx.model_config.model, - trust_remote_code=model_config.trust_remote_code, - revision=ctx.model_config.code_revision) - tokenizer = cached_get_tokenizer( - model_config.tokenizer, - trust_remote_code=model_config.trust_remote_code) - - # NOTE: message formatting for raw text prompt is only applied for - # offline inference; for online serving, the prompt is always in - # instruction format and tokenized. - if prompt is not None and re.match(r"^User:[\s\S]*?(Assistant:)*$", - prompt): - out = processor.process(prompt, image, message_format="none") - elif prompt is not None: - out = processor.process(prompt, image) + tilings = get_candidate_tilings(max_num_patches) + candidate_tilings = np.array(tilings, dtype=np.int32) + candidate_resolutions = candidate_tilings * patch_size + + original_size = np.array([height, width], dtype=np.float32) + required_scale_d = candidate_resolutions.astype(np.float32) / original_size + required_scale = required_scale_d.min(axis=-1, keepdims=True) + + if (required_scale < 1).all(): + ix = required_scale.argmax() else: - out = processor.process(None, image, tokens=inputs["prompt_token_ids"]) - - # If there is no image, return directly. - if image is None: - new_prompt_token_ids = out["input_ids"].tolist() - prompt = inputs.get("prompt") - if prompt is None: - prompt = tokenizer.decode(new_prompt_token_ids) - return token_inputs( - prompt_token_ids=new_prompt_token_ids, - prompt=prompt, + ix = np.where(required_scale < 1.0, 10e9, required_scale).argmin() + + return candidate_tilings[ix] + + +class MolmoProcessorWrapper: + """ + Wraps :class:`MolmoProcessor` so that it can be called directly. + + The original definition can be found here: + https://huggingface.co/allenai/Molmo-7B-D-0924/blob/main/preprocessing_molmo.py + """ + + def __init__(self, processor: ProcessorMixin): + super().__init__() + + self.processor = processor + + @cached_property + def vocab(self) -> dict[str, int]: + return self.processor.tokenizer.vocab # type: ignore + + @cached_property + def max_crops(self) -> int: + image_processor = self.processor.image_processor # type: ignore + + max_crops = image_processor.max_crops + assert isinstance(max_crops, int) + + return max_crops + + @cached_property + def base_image_input_size(self) -> tuple[int, int]: + image_processor = self.processor.image_processor # type: ignore + + base_image_input_size = image_processor.base_image_input_size + if isinstance(base_image_input_size, int): + return base_image_input_size, base_image_input_size + + return tuple(base_image_input_size) + + @cached_property + def image_patch_size(self) -> int: + image_processor = self.processor.image_processor # type: ignore + + image_patch_size = image_processor.image_patch_size + assert isinstance(image_patch_size, int) + + return image_patch_size + + @cached_property + def overlap_margins(self) -> tuple[int, int]: + image_processor = self.processor.image_processor # type: ignore + + left_margin, right_margin = image_processor.overlap_margins + assert isinstance(left_margin, int) + assert isinstance(right_margin, int) + + return left_margin, right_margin + + @cached_property + def image_token_length_w(self) -> int: + image_processor = self.processor.image_processor # type: ignore + + image_token_length_w = image_processor.image_token_length_w + assert isinstance(image_token_length_w, int) + + return image_token_length_w + + @cached_property + def image_token_length_h(self) -> int: + image_processor = self.processor.image_processor # type: ignore + + 
image_token_length_h = image_processor.image_token_length_h + assert isinstance(image_token_length_h, int) + + return image_token_length_h + + @property + def message_format(self) -> Optional[str]: + return "role" + + @property + def always_start_with_space(self) -> bool: + return True + + @cached_property + def image_patch_id(self) -> int: + return self.vocab[IMAGE_PATCH_TOKEN] + + @cached_property + def im_col_id(self) -> int: + return self.vocab[IM_COL_TOKEN] + + @cached_property + def im_start_id(self) -> int: + return self.vocab[IM_START_TOKEN] + + @cached_property + def im_end_id(self) -> int: + return self.vocab[IM_END_TOKEN] + + @property + def pooling_size(self) -> int: + return POOLING_SIZE + + def select_tiling( + self, + *, + image_width: int, + image_height: int, + ) -> tuple[int, int]: + max_crops = self.max_crops + left_margin, right_margin = self.overlap_margins + base_image_input_size = self.base_image_input_size + base_image_input_d = self.image_patch_size + + total_margin_pixels = base_image_input_d * (right_margin + left_margin) + crop_patches = base_image_input_size[0] // base_image_input_d + crop_window_patches = crop_patches - (right_margin + left_margin) + crop_window_size = crop_window_patches * base_image_input_d + tiling_h, tiling_w = select_tiling( + height=image_height - total_margin_pixels, + width=image_width - total_margin_pixels, + patch_size=crop_window_size, + max_num_patches=max_crops, ) - image_processor = processor.image_processor - max_total_crops = 1 + image_processor.max_crops - images, image_input_idx, image_masks = pad_images( - max_total_crops, - out["images"], - out["image_input_idx"], - out.get("image_masks"), - ) - image_data = dict( - images=images, - image_input_idx=image_input_idx, - ) - if image_masks is not None: - image_data["image_masks"] = image_masks - - new_prompt_token_ids = out["input_ids"].tolist() - image_data["seq_len"] = torch.tensor(len(new_prompt_token_ids), - dtype=torch.long) - - multi_modal_data = dict(image=image_data) - size = 0 - offset = -1 - for i in range(len(new_prompt_token_ids)): - if new_prompt_token_ids[i] in (DEFAULT_IMAGE_PATCH_TOKEN_ID, - DEFAULT_IM_START_TOKEN_ID, - DEFAULT_IM_END_TOKEN_ID, - DEFAULT_IM_COL_TOKEN_ID): - if offset < 0: - offset = i - size += 1 - image_data["image_start_end"] = (offset, offset + size) - prompt = inputs.get("prompt") - if prompt is None: - prompt = tokenizer.decode(new_prompt_token_ids) - return token_inputs( - prompt_token_ids=new_prompt_token_ids, - prompt=prompt, - multi_modal_data=multi_modal_data, - multi_modal_placeholders={ - "image": [PlaceholderRange(offset=offset, length=size)] - }, - ) + return tiling_w, tiling_h + + def get_patches_grid_size( + self, + *, + image_width: int, + image_height: int, + ) -> tuple[int, int]: + left_margin, right_margin = self.overlap_margins + base_image_input_size = self.base_image_input_size + base_image_input_d = self.image_patch_size + pooling_size = self.pooling_size + + crop_patches = base_image_input_size[0] // base_image_input_d + tiling_w, tiling_h = self.select_tiling( + image_height=image_height, + image_width=image_width, + ) + + nrows, ncols = get_patches_grid_size( + tiling_h=tiling_h, + tiling_w=tiling_w, + crop_patches=crop_patches, + left_margin=left_margin, + right_margin=right_margin, + pooling_size=pooling_size, + ) + + return ncols, nrows + + def __call__( + self, + text: Optional[Union[TextInput, list[TextInput]]] = None, + images: Optional[Union[ImageInput, list[ImageInput]]] = None, + return_tensors: 
Optional[Union[str, TensorType]] = None, + **kwargs, + ) -> BatchFeature: + outputs = self.processor.process( # type: ignore + text, images, **kwargs) + + if images is None: + images = [] + if not isinstance(images, list): + images = [images] + + input_ids: torch.Tensor = outputs.pop("input_ids") + outputs["input_ids"] = input_ids.unsqueeze(0) + + image_input_idx = outputs.pop("image_input_idx", None) + if image_input_idx is not None: + input_is_patch = input_ids == self.image_patch_id + image_input_idx_flat: torch.Tensor = image_input_idx.view(-1) + image_valid_flat = image_input_idx_flat >= 0 + feat_is_patch_flat = image_valid_flat.clone() + feat_is_patch_flat[image_valid_flat] = ( + input_is_patch[image_input_idx_flat[image_valid_flat]]) + feat_is_patch = feat_is_patch_flat.view(*image_input_idx.shape) + + input_is_embed = torch.isin( + input_ids, + torch.tensor([ + self.image_patch_id, + self.im_col_id, + self.im_start_id, + self.im_end_id, + ]), + ) + embed_ids = input_ids[input_is_embed] + embed_is_patch = embed_ids == self.image_patch_id + assert embed_is_patch.sum() == feat_is_patch.sum() + tilings = [ + self.select_tiling( + image_width=image.size[0], + image_height=image.size[1], + ) for image in images + ] + # For each image: tiling_h * tiling_w + extra + num_crops = torch.tensor(tilings).prod(-1) + 1 + assert num_crops.sum() == len(feat_is_patch) -@MULTIMODAL_REGISTRY.register_image_input_mapper(image_input_mapper_for_molmo) -@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_molmo_image_tokens) -@INPUT_REGISTRY.register_dummy_data(dummy_data_for_molmo) -@INPUT_REGISTRY.register_input_processor(input_processor_for_molmo) + outputs["feat_is_patch"] = feat_is_patch + outputs["embed_is_patch"] = embed_is_patch + outputs["num_crops"] = num_crops + outputs["img_patch_id"] = self.image_patch_id + + return BatchFeature(outputs, tensor_type=return_tensors) + + +class MolmoProcessingInfo(BaseProcessingInfo): + + def get_hf_processor(self) -> MolmoProcessorWrapper: + processor = self.ctx.get_hf_processor() + return MolmoProcessorWrapper(processor) + + def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + return {"image": 1} + + def get_mm_max_tokens_per_item( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> Mapping[str, int]: + return {"image": self.get_max_image_tokens()} + + def get_num_image_tokens( + self, + *, + image_width: int, + image_height: int, + processor: Optional[MolmoProcessorWrapper], + ) -> int: + if processor is None: + processor = self.get_hf_processor() + + ncols, nrows = processor.get_patches_grid_size( + image_width=image_width, + image_height=image_height, + ) + pooling_size = processor.pooling_size + + base_image_input_size = processor.base_image_input_size + base_image_input_d = processor.image_patch_size + + crop_patches = base_image_input_size[0] // base_image_input_d + + per_row = ncols // pooling_size + 1 + joint = per_row * (nrows // pooling_size) + 2 + image_token_length = (crop_patches + pooling_size - 1) // pooling_size + resize = (image_token_length + 1) * image_token_length + 2 + + return resize + joint + + def get_max_image_tokens(self) -> int: + target_width, target_height = self.get_image_size_with_most_features() + + return self.get_num_image_tokens( + image_width=target_width, + image_height=target_height, + processor=None, + ) + + def get_image_size_with_most_features(self) -> ImageSize: + processor = self.get_hf_processor() + + tilings = get_candidate_tilings(processor.max_crops) + base_h, base_w = 
processor.base_image_input_size + + largest_feature_size, largest_feature_pinpoint = 0, None + for wr, hr in tilings: + width, height = base_w * wr, base_h * hr + + feat_size = self.get_num_image_tokens( + image_width=width, + image_height=height, + processor=processor, + ) + if feat_size > largest_feature_size: + largest_feature_size = feat_size + largest_feature_pinpoint = ImageSize(width=width, + height=height) + + if largest_feature_size == 0 or largest_feature_pinpoint is None: + raise ValueError("Cannot have a largest feature size of 0!") + + return largest_feature_pinpoint + + +class MolmoDummyInputsBuilder(BaseDummyInputsBuilder[MolmoProcessingInfo]): + + def get_dummy_processor_inputs( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> ProcessorInputs: + target_width, target_height = \ + self.info.get_image_size_with_most_features() + num_images = mm_counts.get("image", 0) + + mm_data = { + "image": + self._get_dummy_images(width=target_width, + height=target_height, + num_images=num_images) + } + + return ProcessorInputs( + prompt_text="", + mm_data=mm_data, + ) + + +class MolmoMultiModalProcessor(BaseMultiModalProcessor[MolmoProcessingInfo]): + + def _apply_hf_processor_tokens_only( + self, + prompt_tokens: list[int], + ) -> list[int]: + processor = self.info.get_hf_processor() + + # Apply the chat template to the tokens + tokens = processor.processor.get_tokens_input( # type: ignore + self.info.get_tokenizer().decode(prompt_tokens), + message_format=processor.message_format, + always_start_with_space=processor.always_start_with_space, + ) + + processed_data = self.info.ctx.call_hf_processor( + processor, # type: ignore + dict(tokens=tokens), + ) + prompt_ids, = processed_data.pop("input_ids").tolist() + + return prompt_ids + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + num_crops = hf_inputs.get("num_crops", torch.empty(0)) + num_images = len(num_crops) + + return dict( + images=MultiModalFieldConfig.flat_from_sizes("image", num_crops), + image_masks=MultiModalFieldConfig.flat_from_sizes( + "image", num_crops), + feat_is_patch=MultiModalFieldConfig.flat_from_sizes( + "image", num_crops), + embed_is_patch=MultiModalFieldConfig.shared("image", num_images), + num_crops=MultiModalFieldConfig.batched("image"), + img_patch_id=MultiModalFieldConfig.shared("image", num_images), + ) + + def _get_prompt_replacements( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + out_mm_kwargs: MultiModalKwargs, + ) -> list[PromptReplacement]: + processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) + tokenizer = self.info.get_tokenizer() + + image_token_length_w = processor.image_token_length_w + image_token_length_h = processor.image_token_length_h + pooling_size = processor.pooling_size + + user_str = "User:" + if processor.always_start_with_space: + user_str = " " + user_str + + user_tokens = tokenizer.encode(user_str, add_special_tokens=False) + + img_patch_id = processor.image_patch_id + img_col_id = processor.im_col_id + img_start_id = processor.im_start_id + img_end_id = processor.im_end_id + + extra_row = [img_patch_id] * image_token_length_w + [img_col_id] + extra_joint = ([img_start_id] + extra_row * image_token_length_h + + [img_end_id]) + + def get_replacement_molmo(item_idx: int): + images = mm_items.get_items("image", ImageProcessorItems) + image_size = images.get_image_size(item_idx) + + ncols, nrows = 
processor.get_patches_grid_size(
+                image_width=image_size.width,
+                image_height=image_size.height,
+            )
+
+            joint_row = ([img_patch_id] * ((ncols + 1) // pooling_size) +
+                         [img_col_id])
+            joint = ([img_start_id] + joint_row *
+                     ((nrows + 1) // pooling_size) + [img_end_id])
+
+            image_tokens = extra_joint + joint
+
+            return PromptReplacementDetails(
+                full=image_tokens + user_tokens,
+                features=image_tokens,
+            )
+
+        return [
+            PromptReplacement(
+                modality="image",
+                target=user_str,
+                replacement=get_replacement_molmo,
+            )
+        ]
+
+
+@MULTIMODAL_REGISTRY.register_processor(MolmoMultiModalProcessor,
+                                        info=MolmoProcessingInfo,
+                                        dummy_inputs=MolmoDummyInputsBuilder)
 class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP,
                        SupportsLoRA):
     hf_to_vllm_mapper = WeightsMapper(
@@ -1202,6 +1473,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
                                                     quant_config)
         self.model = MolmoModel(vllm_config=vllm_config,
                                 prefix=maybe_prefix(prefix, "model"))
+        self.img_patch_id = None
 
         if self.config.weight_tying:
             self.lm_head = self.model.transformer.wte
@@ -1224,85 +1496,143 @@ def _parse_and_validate_image_input(
         **kwargs: object,
     ) -> Optional[MolmoImageInputs]:
         images = kwargs.pop("images", None)
-        image_masks = kwargs.pop("image_masks", None)
-        image_start_end = kwargs.pop("image_start_end", None)
         if images is None:
             return None
 
-        image_input_idx = kwargs.pop("image_input_idx", None)
-        seq_len = kwargs.pop("seq_len", None)
-        if image_input_idx is None:
-            raise ValueError("image_input_idx is required for Molmo model.")
-        if seq_len is None:
-            raise ValueError("seq_len is required for Molmo model.")
-        if not isinstance(seq_len, torch.Tensor):
-            seq_len = torch.tensor(seq_len)
+        if not isinstance(images, (torch.Tensor, list)):
+            raise ValueError("Incorrect type of images. "
+                             f"Got type: {type(images)}")
+
+        image_masks = kwargs.pop("image_masks", None)
+        if not (image_masks is None or isinstance(image_masks,
+                                                  (torch.Tensor, list))):
+            raise ValueError("Incorrect type of image_masks. "
+                             f"Got type: {type(image_masks)}")
+
+        feat_is_patch = kwargs.pop("feat_is_patch", None)
+        if not isinstance(feat_is_patch, (torch.Tensor, list)):
+            raise ValueError("Incorrect type of feat_is_patch. "
+                             f"Got type: {type(feat_is_patch)}")
+
+        embed_is_patch = kwargs.pop("embed_is_patch", None)
+        if not isinstance(embed_is_patch, (torch.Tensor, list)):
+            raise ValueError("Incorrect type of embed_is_patch. "
+                             f"Got type: {type(embed_is_patch)}")
+
+        num_crops = kwargs.pop("num_crops", None)
+        if not isinstance(num_crops, torch.Tensor):
+            raise ValueError("Incorrect type of num_crops. "
+                             f"Got type: {type(num_crops)}")
+
+        img_patch_id = kwargs.pop("img_patch_id", None)
+        if not isinstance(img_patch_id, torch.Tensor):
+            raise ValueError("Incorrect type of img_patch_id. "
+                             f"Got type: {type(img_patch_id)}")
+        self.img_patch_id = img_patch_id.flatten().unique().item()
 
         return MolmoImageInputs(
             images=images,
-            image_input_idx=image_input_idx,
-            seq_len=seq_len,
             image_masks=image_masks,
-            image_start_end=image_start_end,
+            feat_is_patch=feat_is_patch,
+            embed_is_patch=embed_is_patch,
+            num_crops=num_crops,
         )
 
     def _process_image_input(
         self,
         image_input: MolmoImageInputs,
-    ) -> torch.Tensor:
-
-        image_features = self.vision_backbone(
-            images=image_input["images"],
-            image_masks=image_input["image_masks"],
-        )
+    ) -> Union[torch.Tensor, List[torch.Tensor]]:
+        if isinstance(image_input["images"], list):
+            # Call the vision backbone on the whole batch at once
+            images_flat = flatten_bn(image_input["images"], concat=True)
+            image_masks_flat = (None if (image_masks :=
+                                         image_input["image_masks"]) is None
+                                else flatten_bn(image_masks, concat=True))
+
+            image_features_flat = self.vision_backbone(
+                images=images_flat.unsqueeze(0),
+                image_masks=(None if image_masks_flat is None else
+                             image_masks_flat.unsqueeze(0)),
+            ).squeeze(0)
+
+            # Reconstruct the batch dimension
+            image_features = image_features_flat.split(
+                image_input["num_crops"].sum(-1).tolist())
+        else:
+            image_features = self.vision_backbone(
+                images=image_input["images"],
+                image_masks=image_input["image_masks"],
+            )
 
         return image_features
 
+    def _get_mm_embeds(
+        self,
+        features: torch.Tensor,  # Shape: (num_crop, num_patch, d)
+        feat_is_patch: torch.Tensor,  # Shape: (num_crop, num_patch)
+        num_crops: torch.Tensor,  # Shape: (num_images,)
+        embed_is_patch: torch.Tensor,  # Shape: (num_embeds,)
+    ) -> list[torch.Tensor]:
+        """
+        Scatter the patch features into a contiguous tensor that corresponds
+        to the embedding tokens defined by the multimodal processor.
+
+        Note:
+            The original code only considers patch tokens as feature
+            tokens, but our processor considers all image-related tokens
+            as feature tokens because the feature tokens need to be
+            consecutive in `input_ids`.
+
+        Example:
+            A simplified example for one item in the batch:
+
+            .. code-block::
+
+                Embedding tokens (from HF processor):
+                [ <im_start> <im_patch> <im_patch> <im_col> <im_patch> <im_patch> <im_col> <im_end> ]
+
+                embed_is_patch (from HF processor):
+                [ False True True False True True False False ]
+
+                Encoder outputs (from model):
+                [ p1 p2 0 p3 p4 0 ]
+
+                feat_is_patch (from HF processor):
+                [ True True False True True False ]
+
+                The resulting embedding tensor is:
+                [ nan p1 p2 nan p3 p4 nan nan ]
+        """
+        num_crops_per_image = num_crops.tolist()
+        feats_per_image = features.split(num_crops_per_image)
+        f_is_patch_per_image = feat_is_patch.split(num_crops_per_image)
+
+        _, _, embed_dim = features.shape
+        (num_embeds, ) = embed_is_patch.shape
+
+        embeds_in_batch = list[torch.Tensor]()
+        for feats, f_is_patch in zip(feats_per_image, f_is_patch_per_image):
+            embeds = feats.new_full((num_embeds, embed_dim), torch.nan)
+            embeds[embed_is_patch] = feats[f_is_patch]
+            embeds_in_batch.append(embeds)
+
+        return embeds_in_batch
+
     def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
         image_input = self._parse_and_validate_image_input(**kwargs)
         if image_input is None:
             return None
+
         image_features = self._process_image_input(image_input)
-        image_input_idx = image_input["image_input_idx"]
-        seq_len = image_input["seq_len"]
-        batch_size, num_image, num_patch = image_features.shape[:3]
-        assert image_input_idx.shape == (batch_size, num_image, num_patch)
-
-        # insert the image feature into the embedding.
- image_features = image_features.view(batch_size, num_image * num_patch, - -1) - image_input_idx = image_input_idx.view(batch_size, - num_image * num_patch) - - valid = image_input_idx >= 0 - image_features = image_features * valid[:, :, None].to( - image_features.dtype) - image_features = image_features.view( - batch_size * num_image * num_patch, -1).contiguous() - - image_input_idx = image_input_idx * valid.to(image_input_idx.dtype) - offset = torch.cat([seq_len.new_zeros(1), - seq_len.cumsum(dim=0)[:-1]], - dim=0)[:, None] - image_input_idx = image_input_idx + offset.to(image_input_idx.dtype) - image_input_idx = image_input_idx.flatten()[:, None] - mat = image_input_idx == torch.arange( - seq_len.sum().item(), device=image_features.device)[None, :] - mat = mat.to(image_features.dtype) - - # Note: In this original implementation from AI2, the final - # vision_embeddings will be always be the same length - # of input embeddings. - vision_embeddings = torch.einsum('nd,nm->md', image_features, mat) - - # Split by the sizes of the input sequences. For each full embedding, - # extract the actual vision embeddings to be merged. - vision_embeddings = list(vision_embeddings.split(seq_len.tolist())) - for i in range(len(vision_embeddings)): - start, end = image_input['image_start_end'][i] - vision_embeddings[i] = vision_embeddings[i][start:end] - - return vision_embeddings + + return [ + self._get_mm_embeds(*args) for args in zip( + image_features, + image_input["feat_is_patch"], + image_input["num_crops"], + image_input["embed_is_patch"], + ) + ] def get_input_embeddings( self, @@ -1311,11 +1641,20 @@ def get_input_embeddings( ) -> torch.Tensor: inputs_embeds = self.model.get_input_embeddings(input_ids) if multimodal_embeddings is not None: + assert self.img_patch_id is not None + + # Extract the patch tokens scattered in _get_mm_embeds + patch_embeddings = json_map_leaves( + lambda x: x[~x.isnan()].view(-1, *x.shape[1:]), + cast(JSONTree[torch.Tensor], multimodal_embeddings), + ) + inputs_embeds = merge_multimodal_embeddings( - input_ids, inputs_embeds, multimodal_embeddings, [ - DEFAULT_IMAGE_PATCH_TOKEN_ID, DEFAULT_IM_START_TOKEN_ID, - DEFAULT_IM_END_TOKEN_ID, DEFAULT_IM_COL_TOKEN_ID - ]) + input_ids, + inputs_embeds, + cast(NestedTensors, patch_embeddings), + self.img_patch_id, + ) return inputs_embeds def forward( diff --git a/vllm/multimodal/inputs.py b/vllm/multimodal/inputs.py index 25ca8d1e71f7d..e93fa24a6e4dc 100644 --- a/vllm/multimodal/inputs.py +++ b/vllm/multimodal/inputs.py @@ -353,17 +353,17 @@ def batched(modality: str): Example: - .. code-block:: + .. code-block:: - Input: - Data: [[AAAA] - [BBBB] - [CCCC]] + Input: + Data: [[AAAA] + [BBBB] + [CCCC]] - Output: - Element 1: [AAAA] - Element 2: [BBBB] - Element 3: [CCCC] + Output: + Element 1: [AAAA] + Element 2: [BBBB] + Element 3: [CCCC] """ return MultiModalFieldConfig( field=MultiModalBatchedField(), @@ -384,18 +384,18 @@ def flat(modality: str, slices: Sequence[slice]): Example: - .. code-block:: - - Given: - slices: [slice(0, 3), slice(3, 7), slice(7, 9)] + .. code-block:: + + Given: + slices: [slice(0, 3), slice(3, 7), slice(7, 9)] - Input: - Data: [AAABBBBCC] + Input: + Data: [AAABBBBCC] - Output: - Element 1: [AAA] - Element 2: [BBBB] - Element 3: [CC] + Output: + Element 1: [AAA] + Element 2: [BBBB] + Element 3: [CC] """ return MultiModalFieldConfig( field=MultiModalFlatField(slices=slices), @@ -416,18 +416,18 @@ def flat_from_sizes(modality: str, size_per_item: torch.Tensor): Example: - .. 
code-block:: - - Given: - size_per_item: [3, 4, 2] + .. code-block:: + + Given: + size_per_item: [3, 4, 2] - Input: - Data: [AAABBBBCC] + Input: + Data: [AAABBBBCC] - Output: - Element 1: [AAA] - Element 2: [BBBB] - Element 3: [CC] + Output: + Element 1: [AAA] + Element 2: [BBBB] + Element 3: [CC] See also: :func:`MultiModalFieldConfig.flat` @@ -456,19 +456,19 @@ def shared(modality: str, batch_size: int): Example: - .. code-block:: - - Given: - batch_size: 4 + .. code-block:: + + Given: + batch_size: 4 - Input: - Data: [XYZ] + Input: + Data: [XYZ] - Output: - Element 1: [XYZ] - Element 2: [XYZ] - Element 3: [XYZ] - Element 4: [XYZ] + Output: + Element 1: [XYZ] + Element 2: [XYZ] + Element 3: [XYZ] + Element 4: [XYZ] """ return MultiModalFieldConfig( field=MultiModalSharedField(batch_size), diff --git a/vllm/utils.py b/vllm/utils.py index 6a41afff8f04c..79981fa0953a1 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -33,8 +33,7 @@ from functools import cache, lru_cache, partial, wraps from typing import (TYPE_CHECKING, Any, AsyncGenerator, Awaitable, Callable, Dict, Generator, Generic, Iterator, List, Literal, - NamedTuple, Optional, Tuple, Type, TypeVar, Union, - overload) + NamedTuple, Optional, Tuple, Type, TypeVar, Union) from uuid import uuid4 import cloudpickle @@ -826,38 +825,6 @@ def is_list_of( """A nested JSON structure where the leaves need not be JSON-serializable.""" -@overload -def json_map_leaves( - func: Callable[[T], U], - value: Dict[str, JSONTree[T]], -) -> Dict[str, JSONTree[U]]: - ... - - -@overload -def json_map_leaves( - func: Callable[[T], U], - value: List[JSONTree[T]], -) -> List[JSONTree[U]]: - ... - - -@overload -def json_map_leaves( - func: Callable[[T], U], - value: Tuple[JSONTree[T], ...], -) -> Tuple[JSONTree[U], ...]: - ... - - -@overload -def json_map_leaves( - func: Callable[[T], U], - value: JSONTree[T], -) -> JSONTree[U]: - ... - - def json_map_leaves(func: Callable[[T], U], value: JSONTree[T]) -> JSONTree[U]: if isinstance(value, dict): return {k: json_map_leaves(func, v) for k, v in value.items()}