update according to review comment

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
This commit is contained in:
Wang, Yi A 2024-07-18 01:20:43 -07:00
parent d8e3a27648
commit 47d0863680
4 changed files with 5 additions and 6 deletions

View File

@ -24,7 +24,6 @@ from text_generation_server.models.custom_modeling.vlm import (
load_text_model, load_text_model,
load_vision_model, load_vision_model,
) )
from text_generation_server.layers.attention import Seqlen
class PaliGemmaForConditionalGeneration(nn.Module): class PaliGemmaForConditionalGeneration(nn.Module):
@ -93,7 +92,7 @@ class PaliGemmaForConditionalGeneration(nn.Module):
# insert image features into input embeddings # insert image features into input embeddings
inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1]) inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
input_lengths = Seqlen(input_lengths=input_lengths)
hidden_states = self.text_model.model( hidden_states = self.text_model.model(
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
position_ids=position_ids, position_ids=position_ids,

View File

@ -35,7 +35,6 @@ from text_generation_server.layers import (
TensorParallelRowLinear, TensorParallelRowLinear,
) )
from text_generation_server.utils.weights import DefaultWeightsLoader from text_generation_server.utils.weights import DefaultWeightsLoader
from text_generation_server.layers.attention import Seqlen
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
@ -825,7 +824,7 @@ class Idefics2ForConditionalGeneration(nn.Module):
inputs_embeds = self._merge_input_ids_with_image_features( inputs_embeds = self._merge_input_ids_with_image_features(
input_ids, inputs_embeds, image_hidden_states input_ids, inputs_embeds, image_hidden_states
) )
input_lengths = Seqlen(input_lengths=input_lengths)
hidden_states = self.text_model.model( hidden_states = self.text_model.model(
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
position_ids=position_ids, position_ids=position_ids,

View File

@ -31,7 +31,6 @@ from text_generation_server.layers import (
TensorParallelColumnLinear, TensorParallelColumnLinear,
TensorParallelRowLinear, TensorParallelRowLinear,
) )
from text_generation_server.layers.attention import Seqlen
def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size): def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
@ -269,7 +268,7 @@ class LlavaNextForConditionalGeneration(nn.Module):
inputs_embeds = self._merge_input_ids_with_image_features( inputs_embeds = self._merge_input_ids_with_image_features(
input_ids, inputs_embeds, image_features input_ids, inputs_embeds, image_features
) )
input_lengths = Seqlen(input_lengths=input_lengths)
hidden_states = self.text_model.model( hidden_states = self.text_model.model(
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
position_ids=position_ids, position_ids=position_ids,

View File

@ -14,6 +14,7 @@ from text_generation_server.models.flash_causal_lm import (
FlashCausalLM, FlashCausalLM,
) )
from transformers import AutoProcessor from transformers import AutoProcessor
from text_generation_server.layers.attention import Seqlen
tracer = trace.get_tracer(__name__) tracer = trace.get_tracer(__name__)
@ -342,6 +343,7 @@ class VlmCausalLM(FlashCausalLM):
else: else:
cuda_graph = None cuda_graph = None
if cu_seqlen_prefill is not None or cuda_graph is None: if cu_seqlen_prefill is not None or cuda_graph is None:
input_lengths = Seqlen(input_lengths=input_lengths)
logits, speculative_logits = self.model.forward( logits, speculative_logits = self.model.forward(
input_ids=input_ids, input_ids=input_ids,
position_ids=position_ids, position_ids=position_ids,