@@ -31,7 +31,7 @@
 from vllm.model_executor.models.utils import merge_multimodal_embeddings
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
-from vllm.multimodal.inputs import PlaceholderRange
+from vllm.multimodal.inputs import NestedTensors, PlaceholderRange
 from vllm.multimodal.utils import (cached_get_tokenizer,
                                    consecutive_placeholder_ranges,
                                    resolve_visual_encoder_outputs)
@@ -190,38 +190,47 @@ def sampler(self):
 
         return get_sampler()
 
+    def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
+        image_input = self._parse_and_validate_image_input(**kwargs)
+        if image_input is None:
+            return None
+        vision_embeddings = self._process_image_input(image_input)
+        return vision_embeddings
+
+    def get_input_embeddings(
+        self,
+        input_ids: torch.Tensor,
+        multimodal_embeddings: Optional[NestedTensors] = None,
+    ) -> torch.Tensor:
+        inputs_embeds = self.language_model.get_input_embeddings(input_ids)
+        if multimodal_embeddings is not None:
+            inputs_embeds = merge_multimodal_embeddings(
+                input_ids, inputs_embeds, multimodal_embeddings,
+                self.vision_args.image_token_id)
+        return inputs_embeds
+
     def forward(
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
         **kwargs: object,
     ) -> Union[torch.Tensor, IntermediateTensors]:
         """Run forward pass for pixtral.
-
-        TODO
-
         """
         if intermediate_tensors is not None:
-            input_ids = None
             inputs_embeds = None
-        else:
-            image_input = self._parse_and_validate_image_input(**kwargs)
-
-            if image_input is not None:
-                vision_embeddings = self._process_image_input(image_input)
-                inputs_embeds = self.language_model.model.get_input_embeddings(
-                    input_ids)
 
-                inputs_embeds = merge_multimodal_embeddings(
-                    input_ids, inputs_embeds, vision_embeddings,
-                    self.vision_args.image_token_id)
-
-                input_ids = None
-            else:
-                inputs_embeds = None
+        # NOTE: In v1, inputs_embeds is always generated at the model runner;
+        # this condition is kept for v0 compatibility.
+        elif inputs_embeds is None:
+            vision_embeddings = self.get_multimodal_embeddings(**kwargs)
+            inputs_embeds = self.get_input_embeddings(input_ids,
+                                                      vision_embeddings)
+            input_ids = None
 
         hidden_states = self.language_model.model(input_ids,
                                                   positions,
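
Taken together, the change factors image handling out of forward into two reusable hooks: get_multimodal_embeddings encodes any images passed via **kwargs, and get_input_embeddings merges the result into the token embeddings at the positions marked by vision_args.image_token_id. Below is a minimal sketch of how a caller might drive the split interface; the helper name compute_inputs_embeds and the model handle are assumptions for illustration, not part of this diff.

import torch

def compute_inputs_embeds(model, input_ids: torch.Tensor,
                          **mm_kwargs) -> torch.Tensor:
    # Encode any images in the request; returns None for text-only input.
    vision_embeddings = model.get_multimodal_embeddings(**mm_kwargs)
    # Merge image embeddings into the token embeddings at the image-token
    # placeholder positions; with no images this reduces to a plain
    # embedding lookup.
    return model.get_input_embeddings(input_ids, vision_embeddings)

In v1 the model runner precomputes embeddings this way and calls forward with inputs_embeds set and input_ids as None; the new elif branch reproduces the same two steps inside forward for v0 callers.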