Skip to content

Commit 626f391

Browse files
ywang96 and BKitor
authored and committed
[V1] Update interface for mistral-format Pixtral (vllm-project#10703)
Signed-off-by: Roger Wang <[email protected]>
1 parent 1de1373 commit 626f391

File tree

1 file changed

+28
-19
lines changed

1 file changed

+28
-19
lines changed

vllm/model_executor/models/pixtral.py

+28-19
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
from vllm.model_executor.models.utils import merge_multimodal_embeddings
3232
from vllm.model_executor.sampling_metadata import SamplingMetadata
3333
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
34-
from vllm.multimodal.inputs import PlaceholderRange
34+
from vllm.multimodal.inputs import NestedTensors, PlaceholderRange
3535
from vllm.multimodal.utils import (cached_get_tokenizer,
3636
consecutive_placeholder_ranges,
3737
resolve_visual_encoder_outputs)
@@ -190,38 +190,47 @@ def sampler(self):
190190

191191
return get_sampler()
192192

193+
def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
194+
image_input = self._parse_and_validate_image_input(**kwargs)
195+
if image_input is None:
196+
return None
197+
vision_embeddings = self._process_image_input(image_input)
198+
return vision_embeddings
199+
200+
def get_input_embeddings(
201+
self,
202+
input_ids: torch.Tensor,
203+
multimodal_embeddings: Optional[NestedTensors] = None,
204+
) -> torch.Tensor:
205+
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
206+
if multimodal_embeddings is not None:
207+
inputs_embeds = merge_multimodal_embeddings(
208+
input_ids, inputs_embeds, multimodal_embeddings,
209+
self.vision_args.image_token_id)
210+
return inputs_embeds
211+
193212
def forward(
194213
self,
195214
input_ids: torch.Tensor,
196215
positions: torch.Tensor,
197216
kv_caches: List[torch.Tensor],
198217
attn_metadata: AttentionMetadata,
199218
intermediate_tensors: Optional[IntermediateTensors] = None,
219+
inputs_embeds: Optional[torch.Tensor] = None,
200220
**kwargs: object,
201221
) -> Union[torch.Tensor, IntermediateTensors]:
202222
"""Run forward pass for pixtral.
203-
204-
TODO
205-
206223
"""
207224
if intermediate_tensors is not None:
208-
input_ids = None
209225
inputs_embeds = None
210-
else:
211-
image_input = self._parse_and_validate_image_input(**kwargs)
212-
213-
if image_input is not None:
214-
vision_embeddings = self._process_image_input(image_input)
215-
inputs_embeds = self.language_model.model.get_input_embeddings(
216-
input_ids)
217226

218-
inputs_embeds = merge_multimodal_embeddings(
219-
input_ids, inputs_embeds, vision_embeddings,
220-
self.vision_args.image_token_id)
221-
222-
input_ids = None
223-
else:
224-
inputs_embeds = None
227+
# NOTE: In v1, inputs_embeds is always generated at model runner, this
228+
# condition is for v0 compatibility.
229+
elif inputs_embeds is None:
230+
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
231+
inputs_embeds = self.get_input_embeddings(input_ids,
232+
vision_embeddings)
233+
input_ids = None
225234

226235
hidden_states = self.language_model.model(input_ids,
227236
positions,

0 commit comments

Comments (0)