Fix vlm ?

Nicolas Patry 2024-09-30 15:45:02 +02:00
parent d9fecec000
commit e5476dc04c
2 changed files with 13 additions and 20 deletions


@@ -48,7 +48,7 @@ class PaliGemmaForConditionalGeneration(nn.Module):
             bias=True,
         )
-        self.vocab_size = config.vocab_size
+        self.vocab_size = config.text_config.vocab_size
         self.config = config
 
         text_config = config.text_config


@@ -13,7 +13,6 @@ from text_generation_server.models.flash_causal_lm import (
     FlashCausalLM,
     block_tables_to_ragged,
 )
-from loguru import logger
 from text_generation_server.models.globals import PREFIX_CACHING, ATTENTION
 from text_generation_server.utils.log import log_master
 from transformers import AutoProcessor
@@ -58,6 +57,7 @@ def image_text_replacement(processor, image_input, config, image_id: int) -> str
     elif config.model_type == "llava_next":
         height, width = image_input["image_sizes"][image_id]
         num_features = get_number_of_features(height, width, config)
+        from loguru import logger
 
         log_master(
             logger.info,
@@ -135,13 +135,13 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
 class VlmCausalLMBatch(FlashCausalLMBatch):
     pixel_values: Optional[List[torch.Tensor]]
-    pixel_attention_mask: Optional[List[torch.Tensor]] = None
-    image_sizes: Optional[List[Tuple[int, int]]] = None
+    pixel_attention_mask: Optional[List[torch.Tensor]]
+    image_sizes: Optional[List[Tuple[int, int]]]
 
     @classmethod
     @tracer.start_as_current_span("concatenate")
     def concatenate(cls, batches):
-        batch = super().concatenate(batches)
+        batch = super(VlmCausalLMBatch, cls).concatenate(batches)
         batch.pixel_values = None
         batch.pixel_attention_mask = None
         batch.image_sizes = None
@@ -378,17 +378,6 @@ class VlmCausalLM(FlashCausalLM):
                     max_q=max_s,
                     max_k=max_k,
                 )
-
-                if batch.pixel_values is not None:
-                    cross_attention_states = self.model.vision_forward(
-                        pixel_values=batch.pixel_values,
-                        aspect_ratio_ids=batch.aspect_ratio_ids,
-                        aspect_ratio_mask=batch.aspect_ratio_mask,
-                    )
-                    batch.cross_attention_states = cross_attention_states
-
-                cross_attention_states = batch.cross_attention_states
-
                 logits, speculative_logits = self.model.forward(
                     input_ids=input_ids,
                     position_ids=position_ids,
@ -400,14 +389,18 @@ class VlmCausalLM(FlashCausalLM):
max_s=max_s, max_s=max_s,
prefill_cache_indices=batch.prefill_cache_indices, prefill_cache_indices=batch.prefill_cache_indices,
lm_head_indices=lm_head_indices, lm_head_indices=lm_head_indices,
cross_attention_states=cross_attention_states, pixel_values=batch.pixel_values,
adapter_data=adapter_data, pixel_attention_mask=batch.pixel_attention_mask,
image_indices=batch.image_indices[:], image_sizes=batch.image_sizes,
) )
if batch.prefill_cache_indices is not None: if batch.prefill_cache_indices is not None:
batch.prefill_cache_indices = None batch.prefill_cache_indices = None
if batch.pixel_values is not None: if batch.pixel_values is not None:
batch.pixel_values = None batch.pixel_values = None
if batch.pixel_attention_mask is not None:
batch.pixel_attention_mask = None
if batch.image_sizes is not None:
batch.image_sizes = None
return logits, speculative_logits return logits, speculative_logits
# Copy inputs to the static inputs of the cuda graph # Copy inputs to the static inputs of the cuda graph
@@ -425,7 +418,7 @@ class VlmCausalLM(FlashCausalLM):
             cuda_graph["block_tables"][
                 : block_tables.shape[0], : block_tables.shape[1]
             ] = block_tables
-        cuda_graph["slots"].fill_(0)
+        cuda_graph["slots"].fill_(-1)
         cuda_graph["slots"][: slots.shape[0]] = slots
         cuda_graph["input_lengths"].zero_()
         cuda_graph["input_lengths"][: input_lengths.shape[0]] = (