Mirror of https://github.com/huggingface/text-generation-inference.git
(synced 2025-09-11 20:34:54 +00:00)

Fix vlm ?

parent d9fecec000
commit e5476dc04c
@@ -48,7 +48,7 @@ class PaliGemmaForConditionalGeneration(nn.Module):
             bias=True,
         )
 
-        self.vocab_size = config.vocab_size
+        self.vocab_size = config.text_config.vocab_size
         self.config = config
 
         text_config = config.text_config
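Below is a minimal, illustrative sketch (not from the repository) of the one-line change above: the vocabulary size is read from the nested text_config rather than from the top-level config object. The attribute layout and the concrete numbers are assumptions for illustration only.

    # Illustrative only: a config whose language-model settings live in text_config.
    from types import SimpleNamespace

    config = SimpleNamespace(
        text_config=SimpleNamespace(vocab_size=257152),   # assumed nested sub-config
        vision_config=SimpleNamespace(hidden_size=1152),   # assumed, unused here
    )

    vocab_size = config.text_config.vocab_size  # instead of config.vocab_size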
@@ -13,7 +13,6 @@ from text_generation_server.models.flash_causal_lm import (
     FlashCausalLM,
     block_tables_to_ragged,
 )
-from loguru import logger
 from text_generation_server.models.globals import PREFIX_CACHING, ATTENTION
 from text_generation_server.utils.log import log_master
 from transformers import AutoProcessor
@@ -58,6 +57,7 @@ def image_text_replacement(processor, image_input, config, image_id: int) -> str
     elif config.model_type == "llava_next":
         height, width = image_input["image_sizes"][image_id]
         num_features = get_number_of_features(height, width, config)
+        from loguru import logger
 
         log_master(
             logger.info,
@@ -135,13 +135,13 @@ def get_number_of_features(height: int, width: int, config) -> int:
 
 class VlmCausalLMBatch(FlashCausalLMBatch):
     pixel_values: Optional[List[torch.Tensor]]
-    pixel_attention_mask: Optional[List[torch.Tensor]] = None
-    image_sizes: Optional[List[Tuple[int, int]]] = None
+    pixel_attention_mask: Optional[List[torch.Tensor]]
+    image_sizes: Optional[List[Tuple[int, int]]]
 
     @classmethod
     @tracer.start_as_current_span("concatenate")
     def concatenate(cls, batches):
-        batch = super().concatenate(batches)
+        batch = super(VlmCausalLMBatch, cls).concatenate(batches)
         batch.pixel_values = None
         batch.pixel_attention_mask = None
         batch.image_sizes = None
@@ -378,17 +378,6 @@ class VlmCausalLM(FlashCausalLM):
                 max_q=max_s,
                 max_k=max_k,
             )
-
-            if batch.pixel_values is not None:
-                cross_attention_states = self.model.vision_forward(
-                    pixel_values=batch.pixel_values,
-                    aspect_ratio_ids=batch.aspect_ratio_ids,
-                    aspect_ratio_mask=batch.aspect_ratio_mask,
-                )
-                batch.cross_attention_states = cross_attention_states
-
-            cross_attention_states = batch.cross_attention_states
-
             logits, speculative_logits = self.model.forward(
                 input_ids=input_ids,
                 position_ids=position_ids,
@@ -400,14 +389,18 @@ class VlmCausalLM(FlashCausalLM):
                 max_s=max_s,
                 prefill_cache_indices=batch.prefill_cache_indices,
                 lm_head_indices=lm_head_indices,
-                cross_attention_states=cross_attention_states,
-                adapter_data=adapter_data,
-                image_indices=batch.image_indices[:],
+                pixel_values=batch.pixel_values,
+                pixel_attention_mask=batch.pixel_attention_mask,
+                image_sizes=batch.image_sizes,
             )
             if batch.prefill_cache_indices is not None:
                 batch.prefill_cache_indices = None
             if batch.pixel_values is not None:
                 batch.pixel_values = None
+            if batch.pixel_attention_mask is not None:
+                batch.pixel_attention_mask = None
+            if batch.image_sizes is not None:
+                batch.image_sizes = None
             return logits, speculative_logits
 
         # Copy inputs to the static inputs of the cuda graph
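The forward above passes the vision tensors once and then clears them from the batch. A minimal sketch of that "consume once, then drop" pattern follows; the class and field names are illustrative, not the repository's actual types.

    # Illustrative sketch: vision inputs are only needed for the prefill forward,
    # so they are set back to None afterwards and skipped on later decode steps.
    from dataclasses import dataclass
    from typing import List, Optional, Tuple
    import torch

    @dataclass
    class ExampleBatch:
        pixel_values: Optional[List[torch.Tensor]] = None
        pixel_attention_mask: Optional[List[torch.Tensor]] = None
        image_sizes: Optional[List[Tuple[int, int]]] = None

    def clear_vision_inputs(batch: ExampleBatch) -> None:
        # Mirrors the post-forward cleanup in the hunk above.
        if batch.pixel_values is not None:
            batch.pixel_values = None
        if batch.pixel_attention_mask is not None:
            batch.pixel_attention_mask = None
        if batch.image_sizes is not None:
            batch.image_sizes = None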
@@ -425,7 +418,7 @@ class VlmCausalLM(FlashCausalLM):
         cuda_graph["block_tables"][
             : block_tables.shape[0], : block_tables.shape[1]
         ] = block_tables
-        cuda_graph["slots"].fill_(0)
+        cuda_graph["slots"].fill_(-1)
         cuda_graph["slots"][: slots.shape[0]] = slots
         cuda_graph["input_lengths"].zero_()
         cuda_graph["input_lengths"][: input_lengths.shape[0]] = (
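A small torch sketch of the reset-then-copy pattern in the last hunk: the static CUDA-graph buffer is filled with a sentinel, then only the live entries are overwritten, so stale values from a previous, larger batch cannot leak in. The diff does not explain the choice of -1; treating it as a padding marker for unused entries is an assumption here, and the sizes are made up.

    import torch

    # Static buffer captured by the CUDA graph (illustrative size).
    static_slots = torch.zeros(8, dtype=torch.int64)

    # Slots for the current, smaller batch (illustrative values).
    slots = torch.tensor([3, 7, 12], dtype=torch.int64)

    static_slots.fill_(-1)                    # sentinel for padding entries (assumed meaning)
    static_slots[: slots.shape[0]] = slots    # copy in the live slots
    # static_slots is now [3, 7, 12, -1, -1, -1, -1, -1]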