Fix crash in multi-modal models: wrap input_lengths in Seqlen and pass adapter_data to the text model

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
This commit is contained in:
Wang, Yi A 2024-07-17 22:21:11 -07:00
parent da82c63a4f
commit d8e3a27648
3 changed files with 8 additions and 3 deletions

View File

@@ -24,6 +24,7 @@ from text_generation_server.models.custom_modeling.vlm import (
load_text_model, load_text_model,
load_vision_model, load_vision_model,
) )
from text_generation_server.layers.attention import Seqlen
class PaliGemmaForConditionalGeneration(nn.Module): class PaliGemmaForConditionalGeneration(nn.Module):
@@ -92,7 +93,7 @@ class PaliGemmaForConditionalGeneration(nn.Module):
# insert image features into input embeddings # insert image features into input embeddings
inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1]) inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
input_lengths = Seqlen(input_lengths=input_lengths)
hidden_states = self.text_model.model( hidden_states = self.text_model.model(
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
position_ids=position_ids, position_ids=position_ids,

View File

@@ -35,6 +35,7 @@ from text_generation_server.layers import (
TensorParallelRowLinear, TensorParallelRowLinear,
) )
from text_generation_server.utils.weights import DefaultWeightsLoader from text_generation_server.utils.weights import DefaultWeightsLoader
from text_generation_server.layers.attention import Seqlen
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
@@ -824,7 +825,7 @@ class Idefics2ForConditionalGeneration(nn.Module):
inputs_embeds = self._merge_input_ids_with_image_features( inputs_embeds = self._merge_input_ids_with_image_features(
input_ids, inputs_embeds, image_hidden_states input_ids, inputs_embeds, image_hidden_states
) )
input_lengths = Seqlen(input_lengths=input_lengths)
hidden_states = self.text_model.model( hidden_states = self.text_model.model(
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
position_ids=position_ids, position_ids=position_ids,
@@ -836,6 +837,7 @@ class Idefics2ForConditionalGeneration(nn.Module):
max_s=max_s, max_s=max_s,
true_max_s=max_s, true_max_s=max_s,
prefill_cache_indices=None, prefill_cache_indices=None,
adapter_data=adapter_data,
) )
if lm_head_indices is not None: if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices] hidden_states = hidden_states[lm_head_indices]

View File

@@ -31,6 +31,7 @@ from text_generation_server.layers import (
TensorParallelColumnLinear, TensorParallelColumnLinear,
TensorParallelRowLinear, TensorParallelRowLinear,
) )
from text_generation_server.layers.attention import Seqlen
def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size): def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
@@ -268,7 +269,7 @@ class LlavaNextForConditionalGeneration(nn.Module):
inputs_embeds = self._merge_input_ids_with_image_features( inputs_embeds = self._merge_input_ids_with_image_features(
input_ids, inputs_embeds, image_features input_ids, inputs_embeds, image_features
) )
input_lengths = Seqlen(input_lengths=input_lengths)
hidden_states = self.text_model.model( hidden_states = self.text_model.model(
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
position_ids=position_ids, position_ids=position_ids,
@@ -280,6 +281,7 @@ class LlavaNextForConditionalGeneration(nn.Module):
max_s=max_s, max_s=max_s,
true_max_s=max_s, true_max_s=max_s,
prefill_cache_indices=None, prefill_cache_indices=None,
adapter_data=adapter_data,
) )
if lm_head_indices is not None: if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices] hidden_states = hidden_states[lm_head_indices]