From e5476dc04c1f3ec2b3431dadba2b700ee67febf0 Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Mon, 30 Sep 2024 15:45:02 +0200
Subject: [PATCH] Fix vlm ?

---
 .../flash_pali_gemma_modeling.py |  2 +-
 .../models/vlm_causal_lm.py      | 31 +++++++------------
 2 files changed, 13 insertions(+), 20 deletions(-)

diff --git a/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py b/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py
index d044b492..0024f2bb 100644
--- a/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py
@@ -48,7 +48,7 @@ class PaliGemmaForConditionalGeneration(nn.Module):
             bias=True,
         )
 
-        self.vocab_size = config.vocab_size
+        self.vocab_size = config.text_config.vocab_size
         self.config = config
 
         text_config = config.text_config
diff --git a/server/text_generation_server/models/vlm_causal_lm.py b/server/text_generation_server/models/vlm_causal_lm.py
index bcb33c35..7f7d2e4d 100644
--- a/server/text_generation_server/models/vlm_causal_lm.py
+++ b/server/text_generation_server/models/vlm_causal_lm.py
@@ -13,7 +13,6 @@ from text_generation_server.models.flash_causal_lm import (
     FlashCausalLM,
     block_tables_to_ragged,
 )
-from loguru import logger
 from text_generation_server.models.globals import PREFIX_CACHING, ATTENTION
 from text_generation_server.utils.log import log_master
 from transformers import AutoProcessor
@@ -58,6 +57,7 @@ def image_text_replacement(processor, image_input, config, image_id: int) -> str:
     elif config.model_type == "llava_next":
         height, width = image_input["image_sizes"][image_id]
         num_features = get_number_of_features(height, width, config)
+        from loguru import logger
 
         log_master(
             logger.info,
@@ -135,13 +135,13 @@ def get_number_of_features(height: int, width: int, config) -> int:
 @dataclass
 class VlmCausalLMBatch(FlashCausalLMBatch):
     pixel_values: Optional[List[torch.Tensor]]
-    pixel_attention_mask: Optional[List[torch.Tensor]] = None
-    image_sizes: Optional[List[Tuple[int, int]]] = None
+    pixel_attention_mask: Optional[List[torch.Tensor]]
+    image_sizes: Optional[List[Tuple[int, int]]]
 
     @classmethod
     @tracer.start_as_current_span("concatenate")
     def concatenate(cls, batches):
-        batch = super().concatenate(batches)
+        batch = super(VlmCausalLMBatch, cls).concatenate(batches)
         batch.pixel_values = None
         batch.pixel_attention_mask = None
         batch.image_sizes = None
@@ -378,36 +378,29 @@ class VlmCausalLM(FlashCausalLM):
                 max_q=max_s,
                 max_k=max_k,
             )
-
-            if batch.pixel_values is not None:
-                cross_attention_states = self.model.vision_forward(
-                    pixel_values=batch.pixel_values,
-                    aspect_ratio_ids=batch.aspect_ratio_ids,
-                    aspect_ratio_mask=batch.aspect_ratio_mask,
-                )
-                batch.cross_attention_states = cross_attention_states
-
-            cross_attention_states = batch.cross_attention_states
-
             logits, speculative_logits = self.model.forward(
                 input_ids=input_ids,
                 position_ids=position_ids,
                 cu_seqlen_prefill=cu_seqlen_prefill,
                 kv_cache=kv_cache,
                 block_tables=block_tables,
                 slots=slots,
                 input_lengths=input_lengths,
                 max_s=max_s,
                 prefill_cache_indices=batch.prefill_cache_indices,
                 lm_head_indices=lm_head_indices,
-                cross_attention_states=cross_attention_states,
-                adapter_data=adapter_data,
-                image_indices=batch.image_indices[:],
+                pixel_values=batch.pixel_values,
+                pixel_attention_mask=batch.pixel_attention_mask,
+                image_sizes=batch.image_sizes,
             )
             if batch.prefill_cache_indices is not None:
                 batch.prefill_cache_indices = None
             if batch.pixel_values is not None:
                 batch.pixel_values = None
+            if batch.pixel_attention_mask is not None:
+                batch.pixel_attention_mask = None
+            if batch.image_sizes is not None:
+                batch.image_sizes = None
             return logits, speculative_logits
 
         # Copy inputs to the static inputs of the cuda graph
@@ -425,7 +418,7 @@ class VlmCausalLM(FlashCausalLM):
         cuda_graph["block_tables"][
             : block_tables.shape[0], : block_tables.shape[1]
         ] = block_tables
-        cuda_graph["slots"].fill_(0)
+        cuda_graph["slots"].fill_(-1)
         cuda_graph["slots"][: slots.shape[0]] = slots
         cuda_graph["input_lengths"].zero_()
         cuda_graph["input_lengths"][: input_lengths.shape[0]] = (
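
Reviewer note (not part of the commit): below is a minimal, standalone sketch of the
prefill-then-clear pattern the vlm_causal_lm.py hunks restore, for readers outside the
TGI codebase. Only the field and keyword names (pixel_values, pixel_attention_mask,
image_sizes) mirror the patch; Batch, model_forward, and all values are illustrative
stand-ins, not TGI APIs. The idea: vision inputs ride on the batch for the prefill call
only, then each field is set back to None so later decode steps skip the vision tower.

# Illustrative sketch, not TGI code. Names other than the three vision
# fields are hypothetical.
from dataclasses import dataclass
from typing import List, Optional, Tuple


@dataclass
class Batch:
    input_ids: List[int]
    # Vision inputs are Optional: populated for prefill, None afterwards.
    pixel_values: Optional[List[float]] = None
    pixel_attention_mask: Optional[List[int]] = None
    image_sizes: Optional[List[Tuple[int, int]]] = None


def model_forward(
    input_ids: List[int],
    pixel_values: Optional[List[float]] = None,
    pixel_attention_mask: Optional[List[int]] = None,
    image_sizes: Optional[List[Tuple[int, int]]] = None,
) -> List[float]:
    # Stand-in for the real model: pretend vision features bias the logits.
    vision_bias = len(pixel_values) if pixel_values is not None else 0
    return [float(t + vision_bias) for t in input_ids]


def forward(batch: Batch) -> List[float]:
    # Pass the vision inputs straight through as keyword arguments (the
    # patch swaps these back in place of mllama's cross_attention_states).
    logits = model_forward(
        input_ids=batch.input_ids,
        pixel_values=batch.pixel_values,
        pixel_attention_mask=batch.pixel_attention_mask,
        image_sizes=batch.image_sizes,
    )
    # Clear each vision field once consumed, mirroring the added
    # `if ... is not None` blocks: decode calls then see None and
    # never reprocess the images.
    if batch.pixel_values is not None:
        batch.pixel_values = None
    if batch.pixel_attention_mask is not None:
        batch.pixel_attention_mask = None
    if batch.image_sizes is not None:
        batch.image_sizes = None
    return logits


if __name__ == "__main__":
    batch = Batch(input_ids=[1, 2, 3], pixel_values=[0.5, 0.5])
    print(forward(batch))  # prefill: vision inputs consumed
    print(forward(batch))  # decode: pixel_values is already None

Running the sketch prints biased logits on the first call and plain ones on the second,
which is the observable effect of the cleanup code added by this patch.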