From d8e3a276486b972d18a33330344ec0ce8a15f91b Mon Sep 17 00:00:00 2001
From: "Wang, Yi A" <yi.a.wang@intel.com>
Date: Wed, 17 Jul 2024 22:21:11 -0700
Subject: [PATCH] fix crash in multi-modal

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
---
 .../models/custom_modeling/flash_pali_gemma_modeling.py       | 3 ++-
 .../text_generation_server/models/custom_modeling/idefics2.py | 4 +++-
 .../models/custom_modeling/llava_next.py                      | 4 +++-
 3 files changed, 8 insertions(+), 3 deletions(-)

diff --git a/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py b/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py
index 1f998e5a..62ddf716 100644
--- a/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py
@@ -24,6 +24,7 @@ from text_generation_server.models.custom_modeling.vlm import (
     load_text_model,
     load_vision_model,
 )
+from text_generation_server.layers.attention import Seqlen
 
 
 class PaliGemmaForConditionalGeneration(nn.Module):
@@ -92,7 +93,7 @@ class PaliGemmaForConditionalGeneration(nn.Module):
 
             # insert image features into input embeddings
             inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
-
+        input_lengths = Seqlen(input_lengths=input_lengths)
         hidden_states = self.text_model.model(
             inputs_embeds=inputs_embeds,
             position_ids=position_ids,
diff --git a/server/text_generation_server/models/custom_modeling/idefics2.py b/server/text_generation_server/models/custom_modeling/idefics2.py
index daf3329a..a24e5efc 100644
--- a/server/text_generation_server/models/custom_modeling/idefics2.py
+++ b/server/text_generation_server/models/custom_modeling/idefics2.py
@@ -35,6 +35,7 @@ from text_generation_server.layers import (
     TensorParallelRowLinear,
 )
 from text_generation_server.utils.weights import DefaultWeightsLoader
+from text_generation_server.layers.attention import Seqlen
 
 
 def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
@@ -824,7 +825,7 @@ class Idefics2ForConditionalGeneration(nn.Module):
             inputs_embeds = self._merge_input_ids_with_image_features(
                 input_ids, inputs_embeds, image_hidden_states
             )
-
+        input_lengths = Seqlen(input_lengths=input_lengths)
         hidden_states = self.text_model.model(
             inputs_embeds=inputs_embeds,
             position_ids=position_ids,
@@ -836,6 +837,7 @@ class Idefics2ForConditionalGeneration(nn.Module):
             max_s=max_s,
             true_max_s=max_s,
             prefill_cache_indices=None,
+            adapter_data=adapter_data,
         )
         if lm_head_indices is not None:
             hidden_states = hidden_states[lm_head_indices]
diff --git a/server/text_generation_server/models/custom_modeling/llava_next.py b/server/text_generation_server/models/custom_modeling/llava_next.py
index 567131ef..9fc60733 100644
--- a/server/text_generation_server/models/custom_modeling/llava_next.py
+++ b/server/text_generation_server/models/custom_modeling/llava_next.py
@@ -31,6 +31,7 @@ from text_generation_server.layers import (
     TensorParallelColumnLinear,
     TensorParallelRowLinear,
 )
+from text_generation_server.layers.attention import Seqlen
 
 
 def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
@@ -268,7 +269,7 @@ class LlavaNextForConditionalGeneration(nn.Module):
             inputs_embeds = self._merge_input_ids_with_image_features(
                 input_ids, inputs_embeds, image_features
             )
-
+        input_lengths = Seqlen(input_lengths=input_lengths)
         hidden_states = self.text_model.model(
             inputs_embeds=inputs_embeds,
             position_ids=position_ids,
@@ -280,6 +281,7 @@ class LlavaNextForConditionalGeneration(nn.Module):
             max_s=max_s,
             true_max_s=max_s,
             prefill_cache_indices=None,
+            adapter_data=adapter_data,
        )
         if lm_head_indices is not None:
             hidden_states = hidden_states[lm_head_indices]