From 47d08636806c20047ef27ed2b670ea6e1f0afe07 Mon Sep 17 00:00:00 2001 From: "Wang, Yi A" Date: Thu, 18 Jul 2024 01:20:43 -0700 Subject: [PATCH] update according to review comment Signed-off-by: Wang, Yi A --- .../models/custom_modeling/flash_pali_gemma_modeling.py | 3 +-- .../text_generation_server/models/custom_modeling/idefics2.py | 3 +-- .../models/custom_modeling/llava_next.py | 3 +-- server/text_generation_server/models/vlm_causal_lm.py | 2 ++ 4 files changed, 5 insertions(+), 6 deletions(-) diff --git a/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py b/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py index 62ddf716..1f998e5a 100644 --- a/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py @@ -24,7 +24,6 @@ from text_generation_server.models.custom_modeling.vlm import ( load_text_model, load_vision_model, ) -from text_generation_server.layers.attention import Seqlen class PaliGemmaForConditionalGeneration(nn.Module): @@ -93,7 +92,7 @@ class PaliGemmaForConditionalGeneration(nn.Module): # insert image features into input embeddings inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1]) - input_lengths = Seqlen(input_lengths=input_lengths) + hidden_states = self.text_model.model( inputs_embeds=inputs_embeds, position_ids=position_ids, diff --git a/server/text_generation_server/models/custom_modeling/idefics2.py b/server/text_generation_server/models/custom_modeling/idefics2.py index a24e5efc..58199533 100644 --- a/server/text_generation_server/models/custom_modeling/idefics2.py +++ b/server/text_generation_server/models/custom_modeling/idefics2.py @@ -35,7 +35,6 @@ from text_generation_server.layers import ( TensorParallelRowLinear, ) from text_generation_server.utils.weights import DefaultWeightsLoader -from text_generation_server.layers.attention import Seqlen def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: @@ -825,7 +824,7 @@ class Idefics2ForConditionalGeneration(nn.Module): inputs_embeds = self._merge_input_ids_with_image_features( input_ids, inputs_embeds, image_hidden_states ) - input_lengths = Seqlen(input_lengths=input_lengths) + hidden_states = self.text_model.model( inputs_embeds=inputs_embeds, position_ids=position_ids, diff --git a/server/text_generation_server/models/custom_modeling/llava_next.py b/server/text_generation_server/models/custom_modeling/llava_next.py index 9fc60733..e154d805 100644 --- a/server/text_generation_server/models/custom_modeling/llava_next.py +++ b/server/text_generation_server/models/custom_modeling/llava_next.py @@ -31,7 +31,6 @@ from text_generation_server.layers import ( TensorParallelColumnLinear, TensorParallelRowLinear, ) -from text_generation_server.layers.attention import Seqlen def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size): @@ -269,7 +268,7 @@ class LlavaNextForConditionalGeneration(nn.Module): inputs_embeds = self._merge_input_ids_with_image_features( input_ids, inputs_embeds, image_features ) - input_lengths = Seqlen(input_lengths=input_lengths) + hidden_states = self.text_model.model( inputs_embeds=inputs_embeds, position_ids=position_ids, diff --git a/server/text_generation_server/models/vlm_causal_lm.py b/server/text_generation_server/models/vlm_causal_lm.py index ace48805..ceb59f5a 100644 --- a/server/text_generation_server/models/vlm_causal_lm.py +++ b/server/text_generation_server/models/vlm_causal_lm.py @@ -14,6 +14,7 @@ from text_generation_server.models.flash_causal_lm import ( FlashCausalLM, ) from transformers import AutoProcessor +from text_generation_server.layers.attention import Seqlen tracer = trace.get_tracer(__name__) @@ -342,6 +343,7 @@ class VlmCausalLM(FlashCausalLM): else: cuda_graph = None if cu_seqlen_prefill is not None or cuda_graph is None: + input_lengths = Seqlen(input_lengths=input_lengths) logits, speculative_logits = self.model.forward( input_ids=input_ids, position_ids=position_ids,