Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-06-19 15:52:08 +00:00)

Commit be8e60a918 (parent 7237e8e6bf): add improvements
@@ -39,7 +39,7 @@ httpcore==1.0.7
# via httpx
httpx==0.28.1
# via openai
huggingface-hub==0.29.3
huggingface-hub==0.30.1
# via
# text-generation-integration-tests (pyproject.toml)
# text-generation

@@ -128,9 +128,6 @@ try:
from text_generation_server.models.custom_modeling.flash_neox_modeling import (
FlashGPTNeoXForCausalLM,
)
from text_generation_server.models.pali_gemma import (
PaliGemmaBatch,
)
from text_generation_server.models.custom_modeling.flash_pali_gemma_modeling import (
PaliGemmaForConditionalGeneration,
)
@@ -1196,6 +1193,7 @@ def get_model(
default_dtype=torch.bfloat16,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
support_chunking=False,
)
elif FLASH_TRANSFORMERS_BACKEND:
from transformers import Gemma3ForConditionalGeneration as Gemma3Model
@@ -1208,6 +1206,7 @@ def get_model(
speculator=speculator,
dtype=torch.bfloat16,
trust_remote_code=trust_remote_code,
support_chunking=False,
)
elif sharded:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma3"))
@@ -1583,6 +1582,7 @@ def get_model(
default_dtype=torch.bfloat16,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
support_chunking=False,
)
# TODO: Uncomment when transformers is refactored and cross attn is added
# elif FLASH_TRANSFORMERS_BACKEND:
@@ -1676,7 +1676,6 @@ def get_model(
default_dtype=torch.bfloat16,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
batch_class=PaliGemmaBatch,
)
elif FLASH_TRANSFORMERS_BACKEND:
from transformers import PaliGemmaForConditionalGeneration as PaliGemmaModel
@@ -1689,7 +1688,6 @@ def get_model(
speculator=speculator,
dtype=torch.bfloat16,
trust_remote_code=trust_remote_code,
batch_class=PaliGemmaBatch,
)
else:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("PaliGemma"))

@@ -700,6 +700,7 @@ class Gemma3ForConditionalGeneration(nn.Module):
self.pad_token_id = (
config.pad_token_id if config.pad_token_id is not None else -1
)
self.dtype = weights.dtype

def get_attention_mask(
self,
@@ -762,6 +763,38 @@ class Gemma3ForConditionalGeneration(nn.Module):
else:
return torch.where(full_attention_mask, 0, min_dtype).to(device)

def get_vision_embeds(
self,
pixel_values: torch.FloatTensor,
**kwargs,
):
pixel_values = pixel_values.to(dtype=self.dtype)
image_outputs = self.vision_model(pixel_values)
vision_outputs = self.post_vision_model_layernorm(
image_outputs.last_hidden_state
)
image_features = self.multimodal_projector(vision_outputs)
image_features = image_features.view(-1, image_features.shape[-1])
return image_features

def get_input_embeds(
self,
input_ids: torch.Tensor,
vision_embeds: torch.Tensor = None,
**kwargs,
):
inputs_embeds = self.text_model.embed_tokens(input_ids)

if vision_embeds is not None:
# Replace the image token embeddings with the vision features
image_token_mask = (input_ids == self.config.image_token_index).to(
input_ids.device
)
inputs_embeds[image_token_mask] = vision_embeds.view(
-1, vision_embeds.shape[-1]
)
return inputs_embeds

def forward(
self,
input_ids: torch.Tensor,
@@ -781,26 +814,17 @@ class Gemma3ForConditionalGeneration(nn.Module):
image_sizes: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
image_grid_thw: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
inputs_embeds = self.text_model.embed_tokens(input_ids)
if cu_seqlen_prefill is not None:
max_s += 1
position_ids += 1

if pixel_values is not None:
pixel_values = pixel_values.to(dtype=inputs_embeds.dtype)
image_outputs = self.vision_model(pixel_values)
vision_outputs = self.post_vision_model_layernorm(
image_outputs.last_hidden_state
)
image_features = self.multimodal_projector(vision_outputs)

image_token_mask = (input_ids == self.config.image_token_index).to(
input_ids.device
)
inputs_embeds[image_token_mask] = image_features.view(
-1, image_features.shape[-1]
)

if torch.any(image_token_mask):
attention_mask = self.get_attention_mask(
input_ids,
cu_seqlen_prefill,

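Note on the Gemma3 hunks above: the image path is split out of forward into two reusable steps, get_vision_embeds (project pixel values to the text hidden size) and get_input_embeds (scatter those projected features over the image placeholder positions). A minimal, hedged sketch of the scatter step follows; the toy embedding table, hidden size, and token id 5 are illustrative stand-ins, not the real TGI modules or config values.

import torch

# Assumed toy setup: vocab of 8, hidden size 4, token id 5 plays the role of
# config.image_token_index.
IMAGE_TOKEN_ID = 5
embed_tokens = torch.nn.Embedding(8, 4)

def get_input_embeds(input_ids: torch.Tensor, vision_embeds: torch.Tensor = None) -> torch.Tensor:
    inputs_embeds = embed_tokens(input_ids)
    if vision_embeds is not None:
        # Replace the image token embeddings with the vision features,
        # one row of vision_embeds per placeholder token.
        image_token_mask = input_ids == IMAGE_TOKEN_ID
        inputs_embeds[image_token_mask] = vision_embeds.view(-1, vision_embeds.shape[-1])
    return inputs_embeds

with torch.no_grad():
    input_ids = torch.tensor([1, 5, 5, 2])   # two image placeholders
    vision_embeds = torch.randn(2, 4)        # one projected feature per placeholder
    out = get_input_embeds(input_ids, vision_embeds)
    assert torch.equal(out[1], vision_embeds[0]) and torch.equal(out[2], vision_embeds[1])

The view(-1, hidden) reshape is what lets a single boolean mask assignment place exactly one row per placeholder token.
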
@@ -62,6 +62,37 @@ class PaliGemmaForConditionalGeneration(nn.Module):
self.pad_token_id = (
config.pad_token_id if config.pad_token_id is not None else -1
)
self.dtype = weights.dtype

def get_vision_embeds(
self,
pixel_values: torch.FloatTensor,
**kwargs,
):
pixel_values = pixel_values.to(dtype=self.dtype)
image_outputs = self.vision_tower(pixel_values)
last_hidden_state = self.post_vision_tower_layernorm(
image_outputs.last_hidden_state
)
image_features = self.multi_modal_projector(last_hidden_state)
image_features = image_features.view(
image_features.shape[0], image_features.shape[1], -1
)
return image_features

def get_input_embeds(
self,
input_ids: torch.Tensor,
vision_embeds: torch.Tensor = None,
**kwargs,
):
inputs_embeds = self.text_model.embed_tokens(input_ids)

if vision_embeds is not None:
mask = input_ids == self.config.image_token_index
inputs_embeds[mask] = vision_embeds.view(-1, vision_embeds.shape[-1])

return inputs_embeds

def forward(
self,
@@ -81,27 +112,13 @@ class PaliGemmaForConditionalGeneration(nn.Module):
image_sizes: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
image_grid_thw: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
inputs_embeds = self.text_model.embed_tokens(input_ids)
# TODO This is odd but apparently pali gemma position ids start at 1.
if cu_seqlen_prefill is not None:
max_s += 1
position_ids += 1

if pixel_values is not None:
pixel_values = pixel_values.to(dtype=inputs_embeds.dtype)
image_outputs = self.vision_tower(pixel_values)
last_hidden_state = self.post_vision_tower_layernorm(
image_outputs.last_hidden_state
)
image_features = self.multi_modal_projector(last_hidden_state)

# mask where image or padding tokens
mask = input_ids == self.config.image_token_index

# insert image features into input embeddings
inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])

hidden_states = self.text_model.model(
inputs_embeds=inputs_embeds,
position_ids=position_ids,

@@ -476,38 +476,18 @@ class Idefics3ForConditionalGeneration(nn.Module):
inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
return inputs_embeds

def forward(
def get_vision_embeds(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,
pixel_values: torch.FloatTensor = None,
pixel_attention_mask: Optional[torch.BoolTensor] = None,
# Unused here
image_sizes: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
image_grid_thw: Optional[torch.LongTensor] = None,
video_grid_thw: Optional[torch.LongTensor] = None,
cross_attention_states: Optional[torch.Tensor] = None,
image_indices=None,
pixel_values: torch.FloatTensor,
pixel_attention_mask: torch.BoolTensor,
**kwargs,
):
inputs_embeds = self.text_model.embed_tokens(input_ids)
if pixel_values is not None:
batch_size, num_images, num_channels, height, width = pixel_values.shape
all_states = []
all_pixel_values = pixel_values
all_pixel_mask = pixel_attention_mask
for i in range(batch_size):
pixel_values = all_pixel_values.to(
dtype=self.dtype
) # fp16 compatibility
pixel_values = all_pixel_values.to(dtype=self.dtype) # fp16 compatibility
pixel_values = pixel_values[i : i + 1]
pixel_values = pixel_values.view(num_images, *pixel_values.shape[2:])

@@ -561,10 +541,54 @@ class Idefics3ForConditionalGeneration(nn.Module):
all_states.append(image_hidden_states)
image_hidden_states = torch.stack(all_states, dim=0)

inputs_embeds = self._merge_input_ids_with_image_features(
input_ids, inputs_embeds, image_hidden_states
return image_hidden_states.view(-1, image_hidden_states.shape[-1])

def get_input_embeds(
self,
input_ids: torch.Tensor,
vision_embeds: torch.Tensor = None,
pixel_values: torch.FloatTensor = None,
pixel_attention_mask: torch.BoolTensor = None,
**kwargs,
):
inputs_embeds = self.text_model.embed_tokens(input_ids)
if vision_embeds is None and pixel_values is not None:
vision_embeds = self.get_vision_embeds(
pixel_values=pixel_values,
pixel_attention_mask=pixel_attention_mask,
)

if vision_embeds is not None:
# When we generate, we don't want to replace the potential image_token_id that we generated by images
# that simply don't exist
inputs_embeds = self._merge_input_ids_with_image_features(
input_ids, inputs_embeds, vision_embeds
)
return inputs_embeds

def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,
pixel_values: torch.FloatTensor = None,
pixel_attention_mask: Optional[torch.BoolTensor] = None,
# Unused here
image_sizes: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
image_grid_thw: Optional[torch.LongTensor] = None,
video_grid_thw: Optional[torch.LongTensor] = None,
cross_attention_states: Optional[torch.Tensor] = None,
image_indices=None,
inputs_embeds: Optional[torch.Tensor] = None,
):
hidden_states = self.text_model.model(
inputs_embeds=inputs_embeds,
position_ids=position_ids,

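Note on the Idefics3 hunks above: get_input_embeds now accepts either precomputed vision_embeds or raw pixel_values, and only encodes on demand when nothing was precomputed, which is what lets the runtime reuse cached encoder outputs across prefill chunks. A small sketch of that "reuse if present, otherwise encode once" pattern; encode_image and the plain dict cache are assumptions standing in for the model's vision tower and the batch's encoder cache.

import torch

encoder_cache: dict[int, torch.Tensor] = {}

def encode_image(pixel_values: torch.Tensor) -> torch.Tensor:
    # Pretend vision tower: mean-pool all pixels to one 4-dim feature per image.
    return pixel_values.flatten(1).mean(dim=1, keepdim=True).repeat(1, 4)

def vision_embeds_for(image_id: int, pixel_values: torch.Tensor) -> torch.Tensor:
    cached = encoder_cache.get(image_id)
    if cached is None:
        cached = encode_image(pixel_values)
        encoder_cache[image_id] = cached  # reused instead of re-encoding later
    return cached

feats_first = vision_embeds_for(0, torch.randn(1, 3, 8, 8))
feats_again = vision_embeds_for(0, torch.randn(1, 3, 8, 8))  # cache hit, image not re-encoded
assert torch.equal(feats_first, feats_again)
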
@@ -163,27 +163,12 @@ class LlavaNextForConditionalGeneration(nn.Module):
)
return inputs_embeds

def forward(
def get_vision_embeds(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,
pixel_values: torch.FloatTensor = None,
# Unused for this model
pixel_attention_mask=None,
pixel_values: torch.FloatTensor,
image_sizes: Optional[torch.LongTensor] = None,
adapter_data: Optional[torch.Tensor] = None,
image_grid_thw: Optional[torch.LongTensor] = None,
**kwargs,
):
inputs_embeds = self.text_model.embed_tokens(input_ids)
if pixel_values is not None and len(pixel_values) > 0:
# num_special_image_tokens = (input_ids == self.config.image_token_index).sum()
# assert num_special_image_tokens == len(pixel_values), f"Received {num_special_image_tokens} for {len(pixel_values)} images, this is invalid"
# 1. Extract the input embeddings
@@ -218,8 +203,7 @@ class LlavaNextForConditionalGeneration(nn.Module):

# NOTE we only support multimodal_patch_merge_type == "spatial_unpad"
height = width = (
self.config.vision_config.image_size
// self.config.vision_config.patch_size
self.config.vision_config.image_size // self.config.vision_config.patch_size
)

new_image_features = []
@@ -256,9 +240,7 @@ class LlavaNextForConditionalGeneration(nn.Module):
dim=-1,
)
image_feature = image_feature.flatten(1, 2).transpose(0, 1)
image_feature = torch.cat(
(base_image_feature, image_feature), dim=0
)
image_feature = torch.cat((base_image_feature, image_feature), dim=0)
else:
image_feature = image_feature[0]
image_feature = torch.cat(
@@ -266,11 +248,51 @@ class LlavaNextForConditionalGeneration(nn.Module):
)
new_image_features.append(image_feature)
image_features = torch.stack(new_image_features, dim=0)
return image_features.view(-1, image_features.shape[-1])

inputs_embeds = self._merge_input_ids_with_image_features(
input_ids, inputs_embeds, image_features
def get_input_embeds(
self,
input_ids: torch.Tensor,
vision_embeds: torch.Tensor = None,
pixel_values: torch.FloatTensor = None,
image_sizes: Optional[torch.LongTensor] = None,
**kwargs,
):
inputs_embeds = self.text_model.embed_tokens(input_ids)
if vision_embeds is None and pixel_values is not None:
vision_embeds = self.get_vision_embeds(
pixel_values=pixel_values,
image_sizes=image_sizes,
)

if vision_embeds is not None:
# When we generate, we don't want to replace the potential image_token_id that we generated by images
# that simply don't exist
inputs_embeds = self._merge_input_ids_with_image_features(
input_ids, inputs_embeds, vision_embeds
)
return inputs_embeds

def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,
pixel_values: torch.FloatTensor = None,
# Unused for this model
pixel_attention_mask=None,
image_sizes: Optional[torch.LongTensor] = None,
adapter_data: Optional[torch.Tensor] = None,
image_grid_thw: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
):
hidden_states = self.text_model.model(
inputs_embeds=inputs_embeds,
position_ids=position_ids,

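Note on the LlavaNext hunks above: get_vision_embeds now returns image features already flattened to (num_image_tokens, hidden) via view(-1, shape[-1]), so callers can scatter them with a single 1-D mask over input_ids. A short shape-bookkeeping sketch with made-up sizes and an assumed image token id of 9:

import torch

image_features = torch.randn(2, 3, 8)                  # 2 images, 3 tokens each, hidden=8
flat = image_features.view(-1, image_features.shape[-1])
assert flat.shape == (6, 8)

# The number of image placeholder tokens in input_ids must match flat.shape[0],
# otherwise the masked assignment in get_input_embeds would raise a size error.
input_ids = torch.tensor([7, 9, 9, 9, 1, 9, 9, 9, 2])  # 9 = assumed image token id
assert int((input_ids == 9).sum()) == flat.shape[0]
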
@@ -959,6 +959,7 @@ class MllamaForConditionalGeneration(nn.Module):
# XXX: Putting these as optional so that the cuda warmup calls can go through.
cross_attention_states: Optional[torch.Tensor] = None,
image_indices=None,
inputs_embeds=None,
):
if cross_attention_states is not None:
seqlen_q = len(image_indices)

@@ -922,6 +922,29 @@ class Qwen2_5VLForConditionalGeneration(nn.Module):
)
return position_ids

def get_vision_embeds(
self,
pixel_values: torch.FloatTensor,
image_grid_thw: Optional[torch.LongTensor] = None,
**kwargs,
):
image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw).squeeze(0)
return image_embeds

def get_input_embeds(
self,
input_ids: torch.Tensor,
vision_embeds: torch.Tensor = None,
**kwargs,
):
inputs_embeds = self.embed_tokens(input_ids)

# apply the visual model to the pixel values if they are provided
if vision_embeds is not None and len(vision_embeds) > 0:
inputs_embeds[input_ids == self.image_token_id] = vision_embeds

return inputs_embeds

def forward(
self,
input_ids: torch.Tensor,
@@ -943,17 +966,8 @@ class Qwen2_5VLForConditionalGeneration(nn.Module):
adapter_data: Optional[torch.Tensor] = None,
cross_attention_states: Optional[torch.Tensor] = None,
image_indices=None,
inputs_embeds: Optional[torch.Tensor] = None,
):
inputs_embeds = self.embed_tokens(input_ids)

# apply the visual model to the pixel values if they are provided
if pixel_values is not None and len(pixel_values) > 0:
if pixel_values is not None:
image_embeds = self.visual(
pixel_values, grid_thw=image_grid_thw
).squeeze(0)
inputs_embeds[input_ids == self.image_token_id] = image_embeds

hidden_states = self.text_model(
inputs_embeds=inputs_embeds,
position_ids=position_ids,

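Note on the Qwen2.5-VL hunks above: image handling moves into get_input_embeds, which keeps a `len(vision_embeds) > 0` guard so the scatter is skipped entirely when no vision embeddings were produced (for example, text-only requests in a mixed batch). A minimal illustration of that guard; the Embedding table and token id 4 are assumed stand-ins for self.embed_tokens and self.image_token_id.

import torch

IMAGE_TOKEN_ID = 4
embed = torch.nn.Embedding(10, 6)

def merge(input_ids: torch.Tensor, vision_embeds: torch.Tensor = None) -> torch.Tensor:
    with torch.no_grad():
        inputs_embeds = embed(input_ids)
        # Only scatter when some vision embeddings actually exist.
        if vision_embeds is not None and len(vision_embeds) > 0:
            inputs_embeds[input_ids == IMAGE_TOKEN_ID] = vision_embeds
    return inputs_embeds

text_only = merge(torch.tensor([1, 2, 3]), torch.empty(0, 6))   # guard skips the scatter
with_image = merge(torch.tensor([1, 4, 4]), torch.randn(2, 6))  # two rows scattered in place
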
@@ -500,6 +500,29 @@ class Qwen2VLForConditionalGeneration(nn.Module):
)
return position_ids

def get_vision_embeds(
self,
pixel_values: torch.FloatTensor,
image_grid_thw: Optional[torch.LongTensor] = None,
**kwargs,
):
image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw).squeeze(0)
return image_embeds

def get_input_embeds(
self,
input_ids: torch.Tensor,
vision_embeds: torch.Tensor = None,
**kwargs,
):
inputs_embeds = self.embed_tokens(input_ids)

# apply the visual model to the pixel values if they are provided
if vision_embeds is not None and len(vision_embeds) > 0:
inputs_embeds[input_ids == self.image_token_id] = vision_embeds

return inputs_embeds

def forward(
self,
input_ids: torch.Tensor,
@@ -520,17 +543,8 @@ class Qwen2VLForConditionalGeneration(nn.Module):
adapter_data: Optional[torch.Tensor] = None,
cross_attention_states: Optional[torch.Tensor] = None,
image_indices=None,
inputs_embeds: Optional[torch.Tensor] = None,
):
inputs_embeds = self.embed_tokens(input_ids)

# apply the visual model to the pixel values if they are provided
if pixel_values is not None and len(pixel_values) > 0:
if pixel_values is not None:
image_embeds = self.visual(
pixel_values, grid_thw=image_grid_thw
).squeeze(0)
inputs_embeds[input_ids == self.image_token_id] = image_embeds

hidden_states = self.text_model(
inputs_embeds=inputs_embeds,
position_ids=position_ids,

@@ -29,6 +29,9 @@ class MllamaCausalLMBatch(VlmCausalLMBatch):
aspect_ratio_mask: Optional[torch.Tensor] = None
cross_attention_states: Optional[torch.Tensor] = None

def prepare_for_prefill(self):
super(VlmCausalLMBatch, self).prepare_for_prefill()

@classmethod
@tracer.start_as_current_span("concatenate")
def concatenate(cls, batches):
@@ -196,6 +199,9 @@ class MllamaCausalLMBatch(VlmCausalLMBatch):


class MllamaCausalLM(VlmCausalLM):
def get_input_embeddings(self, batch):
batch.inputs_embeds = None

def forward(
self,
batch: MllamaCausalLMBatch,

@@ -163,6 +163,7 @@ class TransformersFlashVlmCausalLM(VlmCausalLM):
processor_kwargs=None,
kv_cache_dtype: Optional[torch.dtype] = None,
batch_class=VlmCausalLMBatch,
support_chunking: bool = True,
):
self.batch_class = batch_class
self.quantize = quantize
@@ -304,7 +305,7 @@ class TransformersFlashVlmCausalLM(VlmCausalLM):
device=device,
rank=rank,
world_size=world_size,
support_chunking=True,
support_chunking=support_chunking,
)

# Monkey patch of `self.model.forward` to match `FlashCausalLM`. It avoids duplicating a lot of code
@@ -339,6 +340,7 @@ class TransformersFlashVlmCausalLM(VlmCausalLM):
trust_remote_code: bool = False,
batch_class: Optional[type] = VlmCausalLMBatch,
processor_kwargs: Optional[dict] = None,
support_chunking: bool = True,
):
return cls(
model_id=model_id,
@@ -350,6 +352,7 @@ class TransformersFlashVlmCausalLM(VlmCausalLM):
trust_remote_code=trust_remote_code,
batch_class=batch_class,
processor_kwargs=processor_kwargs,
support_chunking=support_chunking,
)

def _model_forward(

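Note on these hunks: support_chunking becomes a constructor and from_pretrained parameter instead of a hard-coded value, so individual models (for example Gemma3 in the get_model hunks earlier, which pass support_chunking=False) can opt out of prefill chunking. A hedged sketch of that keyword plumbing with stand-in classes; the real TGI classes take many more arguments.

class Base:
    def __init__(self, support_chunking: bool = True):
        self.support_chunking = support_chunking

class VlmLM(Base):
    def __init__(self, support_chunking: bool = True, **kwargs):
        # Previously this was effectively hard-coded; now the caller decides.
        super().__init__(support_chunking=support_chunking, **kwargs)

    @classmethod
    def from_pretrained(cls, support_chunking: bool = True, **kwargs):
        return cls(support_chunking=support_chunking, **kwargs)

assert VlmLM.from_pretrained().support_chunking is True
assert VlmLM.from_pretrained(support_chunking=False).support_chunking is False
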
@@ -13,7 +13,7 @@ from text_generation_server.models.flash_causal_lm import (
FlashCausalLMBatch,
FlashCausalLM,
)
from text_generation_server.models.globals import PREFIX_CACHING, ATTENTION
from text_generation_server.models.globals import PREFIX_CACHING, ATTENTION, MEM_POOL
from loguru import logger
from text_generation_server.utils.log import log_master
from transformers import AutoProcessor
@@ -119,8 +119,8 @@ def image_text_replacement(processor, image_input, config, image_id: int) -> str
return image_str, IDEFICS2_FAKE_TOKEN
if config.model_type == "idefics3":
# TODO: implement this in a more general way
n_rows = image_input["rows"][0][image_id]
n_cols = image_input["cols"][0][image_id]
n_rows = image_input[image_id]["rows"][0][0]
n_cols = image_input[image_id]["cols"][0][0]
image_seq_len = int(
((config.vision_config.image_size // config.vision_config.patch_size) ** 2)
/ (config.scale_factor**2)
@@ -135,7 +135,7 @@ def image_text_replacement(processor, image_input, config, image_id: int) -> str
)
return image_str, IDEFICS3_FAKE_IMAGE_TOKEN
elif config.model_type == "llava_next":
height, width = image_input["image_sizes"][image_id]
height, width = image_input[image_id]["image_sizes"][0]
num_features = get_number_of_features(height, width, config)

log_master(
@@ -147,12 +147,12 @@ def image_text_replacement(processor, image_input, config, image_id: int) -> str
elif config.model_type == "paligemma":
return "<image>" * config.text_config.num_image_tokens, "<image>"
elif config.model_type == "qwen2_vl":
grid_t, grid_h, grid_w = image_input["image_grid_thw"][image_id]
grid_t, grid_h, grid_w = image_input[image_id]["image_grid_thw"][0]
num_pads = grid_t * grid_h * grid_w // 4
padding = "<|image_pad|>" * num_pads
return f"<|vision_start|>{padding}<|vision_end|>", "<|vision_start|>"
elif config.model_type == "qwen2_5_vl":
grid_t, grid_h, grid_w = image_input["image_grid_thw"][image_id]
grid_t, grid_h, grid_w = image_input[image_id]["image_grid_thw"][0]
num_pads = grid_t * grid_h * grid_w // 4
padding = "<|image_pad|>" * num_pads
return f"<|vision_start|>{padding}<|vision_end|>", "<|vision_start|>"
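
Note on the image_text_replacement hunks above: image inputs are now passed around as a list with one processor output per image rather than a single batched dict, so per-image metadata is read as image_input[image_id]["rows"][0][0] and friends. A tiny illustration of the two layouts with made-up values:

# Old layout: one batched dict covering all images, indexed by image_id inside.
batched_style = {"rows": [[2, 3]], "cols": [[2, 1]]}
# New layout: one processor-output dict per image, indexed by image_id outside.
per_image_style = [
    {"rows": [[2]], "cols": [[2]]},
    {"rows": [[3]], "cols": [[1]]},
]

image_id = 1
assert batched_style["rows"][0][image_id] == per_image_style[image_id]["rows"][0][0] == 3
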
@@ -344,8 +344,155 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
def batch_tokenized_inputs(
cls, requests: Iterable[generate_pb2.Request], tokenizer, processor, config
):
# Process images first. We need all of them so that the processor
# can make the image splits the same size. And we need the final
kwargs = {}
if (
hasattr(processor, "image_processor_class")
and processor.image_processor_class == "Idefics3ImageProcessor"
):
kwargs["return_row_col_info"] = True

max_length = 0
vocab = tokenizer.get_vocab()
config.image_token_index = (
config.image_token_index
if hasattr(config, "image_token_index")
else config.image_token_id
)

batch_tokenized_inputs: List[List[int]] = []
batch_image_inputs: List[Optional[List[dict]]] = []
batch_image_positions: List[Optional[List[ImagePositions]]] = []

for i, r in enumerate(requests):
text_parts = []
image_inputs = []
image_texts = []

image_id = 0

for chunk in r.input_chunks.chunks:
chunk_type = chunk.WhichOneof("chunk")
if chunk_type == "text":
text_parts.append(chunk.text)
continue

if chunk_type != "image":
raise RuntimeError(f"Invalid chunk type {chunk_type}")

img = Image.open(BytesIO(chunk.image.data))

if config.model_type in {"qwen2_vl", "qwen2_5_vl"} and img.width <= 20:
img = img.resize((img.width * 2, img.height * 2))

if config.model_type in {"paligemma"}:
img = img.convert("RGB")

if config.model_type not in {"llava_next", "gemma3", "llama4"}:
img = [img]

image_input = processor.image_processor(
[img], return_tensors="pt", **kwargs
)
image_inputs.append(image_input)

img_text, id_token_str = image_text_replacement(
processor, image_input, config, 0
)

text_parts.append(img_text)

image_texts.append([image_id, id_token_str, img_text])
image_id += 1

full_text = image_text_replacement_fixup(config, "".join(text_parts))
input_ids = tokenizer(
full_text,
truncation=True,
max_length=r.truncate,
add_special_tokens=r.add_special_tokens,
)["input_ids"]
max_length = max(max_length, len(input_ids))

if len(image_inputs) > 0:
img_start_token = vocab[image_texts[0][1]]
image_positions = cls.get_image_positions(
input_ids, image_texts, img_start_token, config, tokenizer
)
else:
image_inputs = None
image_positions = None

batch_tokenized_inputs.append(input_ids)
batch_image_inputs.append(image_inputs)
batch_image_positions.append(image_positions)

return batch_tokenized_inputs, batch_image_inputs, batch_image_positions

@classmethod
def get_image_positions(
cls,
input_ids: List[int],
image_texts: List[Tuple[int, str, str]],
img_start_token: int,
config,
tokenizer: PreTrainedTokenizerBase,
) -> List[ImagePositions]:
image_positions = []
num_images = len(image_texts)

input_ids_t = torch.as_tensor(input_ids, dtype=torch.int32)
img_start_token_pos = torch.where(input_ids_t.eq(img_start_token))[0]

last_pos = 0
for i in range(num_images):
image_id, img_start_token_str, img_text = image_texts[i]
img_text = image_text_replacement_fixup(config, img_text)
if config.model_type == "gemma3":
img_text = img_text.replace("\n\n", "")

tokens = tokenizer(img_text, add_special_tokens=False)["input_ids"]

pos = torch.searchsorted(img_start_token_pos, last_pos, right=False)
index = img_start_token_pos[pos]

is_embed = torch.tensor(tokens) == config.image_token_index
num_placeholder_tokens = is_embed.sum().item()

length = len(tokens)
if num_placeholder_tokens == length:
is_embed = None

pos = ImagePositions(
offset=index,
length=length,
id=image_id,
num_placeholder_tokens=num_placeholder_tokens,
is_embed=is_embed,
)

image_positions.append(pos)
last_pos = index + length

if (
config.model_type == "idefics2"
and i + 1 != num_images
and input_ids[last_pos] == config.image_token_index
):
fake_token = last_pos - 1
fake_token_index = torch.searchsorted(
img_start_token_pos, fake_token, right=False
)
img_start_token_pos[fake_token_index] = last_pos
image_texts[i + 1][2] = image_texts[i + 1][2][
len(img_start_token_str) :
]

return image_positions

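Note on get_image_positions above: for every image it records where the image's replacement text starts in input_ids (offset), how many tokens it occupies (length), and which of those tokens are real embedding slots (is_embed, left as None when all of them are). A toy walk-through of that bookkeeping, using assumed token ids 10 (image start token) and 11 (image placeholder):

import torch

IMG_START, IMG_PLACEHOLDER = 10, 11
input_ids = [1, 10, 11, 11, 11, 2, 10, 11, 11, 3]
img_text_tokens = [10, 11, 11, 11]          # tokenization of one image's replacement text

input_ids_t = torch.as_tensor(input_ids, dtype=torch.int32)
start_positions = torch.where(input_ids_t.eq(IMG_START))[0]   # tensor([1, 6])

offset = int(start_positions[0])             # first image starts at position 1
length = len(img_text_tokens)                # 4 tokens reserved in the prompt
is_embed = torch.tensor(img_text_tokens) == IMG_PLACEHOLDER
num_placeholder_tokens = int(is_embed.sum())  # 3 of the 4 slots carry vision embeddings

assert (offset, length, num_placeholder_tokens) == (1, 4, 3)
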
@classmethod
def batch_tokenized_inputs2(
cls, requests: Iterable[generate_pb2.Request], tokenizer, processor, config
):
# sizes to insert correct number of image tokens.
kwargs = {}
if (
@@ -374,21 +521,20 @@ class VlmCausalLMBatch(FlashCausalLMBatch):

if config.model_type in {"llava_next", "gemma3", "llama4"}:
image = image
elif config.model_type in {"paligemma"}:
image = image.convert("RGB")
else:
image = [image]
pixel_values = processor.image_processor(
image_input = processor.image_processor(
[image], return_tensors="pt", **kwargs
)

image_inputs.append(pixel_values)
image_inputs.append(image_input)
else:
raise RuntimeError(f"Invalid chunk type {chunk_type}")

if len(image_inputs) > 0:
batch_image_inputs[i] = image_inputs
# pixel_values = processor.image_processor(
# all_images, return_tensors="pt", **kwargs
# )

batch_image_positions = []
batch_tokenized_inputs = []
@@ -554,29 +700,8 @@ class VlmCausalLMBatch(FlashCausalLMBatch):

if image_id not in self.encoder_cache[i]:
self.pixel_values.append((i, image_position, image_inputs))
# scheduled_image_pixel_values.append(image_inputs)
self.image_inputs[i][j] = None

# if self.has_image and len(scheduled_image_pixel_values):
# self.pixel_values = [
# d["pixel_values"].to(device) for d in scheduled_image_pixel_values
# ]

# if "pixel_attention_mask" in scheduled_image_pixel_values[0]:
# self.pixel_attention_mask = [
# d["pixel_attention_mask"].to(device)
# for d in scheduled_image_pixel_values
# ]

# if "image_sizes" in scheduled_image_pixel_values[0]:
# self.image_sizes = [
# d["image_sizes"].to(device) for d in scheduled_image_pixel_values
# ]

# if "image_grid_thw" in scheduled_image_pixel_values[0]:
# self.image_grid_thw = [
# d["image_grid_thw"].to(device) for d in scheduled_image_pixel_values
# ]
if not self.has_image:
self.pixel_values = None
self.pixel_attention_mask = None
@@ -637,12 +762,21 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
if is_embed is not None:
is_embed = is_embed[start_idx:end_idx]

from loguru import logger

logger.info(
f"image_id {image_id} start_idx {start_idx} end_idx {end_idx}, length {length}"
)

mm_embeds_item = gather_image_embeds(
encoder_output[start_idx:end_idx],
is_embed=is_embed,
)
if mm_embeds_item is not None:
mm_embeds.append(mm_embeds_item)

if len(mm_embeds) == 0:
return None
return torch.cat(mm_embeds, dim=0).to(device)

def free_encoder_cache(self):
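
Note on the hunk above: it slices the cached encoder output for the part of an image that falls inside the current prefill window and filters it with is_embed. gather_image_embeds itself is not shown in this diff; the following is only an assumed, plausible reading of its contract, written as a sketch rather than the actual TGI helper.

import torch
from typing import Optional

def gather_image_embeds(
    chunk: torch.Tensor, is_embed: Optional[torch.Tensor]
) -> Optional[torch.Tensor]:
    # Assumed contract: keep only rows that correspond to real embedding slots;
    # is_embed=None means every row in the chunk is an embedding slot.
    if is_embed is not None:
        chunk = chunk[is_embed]
    return chunk if chunk.shape[0] > 0 else None

encoder_output = torch.randn(4, 16)
mask = torch.tensor([True, False, True, True])
assert gather_image_embeds(encoder_output, mask).shape == (3, 16)
assert gather_image_embeds(encoder_output[:0], None) is None
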
@@ -662,6 +796,7 @@ class VlmCausalLM(FlashCausalLM):
batch_class=VlmCausalLMBatch,
revision,
trust_remote_code: bool,
support_chunking: bool = True,
**kwargs,
):
if PREFIX_CACHING:
@@ -679,8 +814,7 @@ class VlmCausalLM(FlashCausalLM):
model_id=model_id,
revision=revision,
trust_remote_code=trust_remote_code,
# FIXME: VLM do not work with context chunking yet
support_chunking=False,
support_chunking=support_chunking,
**kwargs,
)

@@ -688,6 +822,153 @@ class VlmCausalLM(FlashCausalLM):
def batch_type(self) -> Type[VlmCausalLMBatch]:
return self.batch_class

def cuda_graph_warmup(self, bs: int, max_s: int, max_bt: int):
max_bs = max(self.cuda_graphs.keys()) if self.cuda_graphs else None
input_lengths = [max_s] * bs
cache_lengths = [0] * bs
if max_bs is None:
input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device)
input_embeds = torch.zeros(
(bs, self.model.config.text_config.hidden_size),
device=self.device,
dtype=self.dtype,
)
position_ids = torch.zeros(bs, dtype=torch.int32, device=self.device)
config = getattr(self.model, "config", None)
rope_scaling = getattr(config, "rope_scaling", None) if config else None
if ( # mrope have position_ids per section, if so repeat n times
isinstance(rope_scaling, dict) and rope_scaling["rope_type"] == "mrope"
):
n_sections = len(self.model.config.rope_scaling["mrope_section"])
position_ids = position_ids.unsqueeze(1).repeat(1, n_sections)
slots = torch.arange(bs, dtype=torch.int64, device=self.device)
input_lengths_tensor = (
torch.ones(bs, dtype=torch.int32, device=self.device) * max_s
)
cache_lengths_tensor = torch.zeros(
bs, dtype=torch.int32, device=self.device
)
block_tables = torch.arange(
max_bt, dtype=torch.int32, device=self.device
).repeat(bs)
block_tables = block_tables.reshape((bs, max_bt))
if ATTENTION == "flashinfer":
block_tables = block_tables_to_ragged(
block_tables=block_tables,
input_lengths=input_lengths,
cache_lengths=cache_lengths,
input_lengths_tensor=input_lengths_tensor,
cache_lengths_tensor=cache_lengths_tensor,
max_current_length=max_s,
)
else:
if bs > max_bs:
raise RuntimeError(
"Cuda graphs should be generated in decreasing order size to reduce VRAM usage"
)
input_ids = self.cuda_graphs[max_bs]["input_ids"][:bs]
input_embeds = self.cuda_graphs[max_bs]["input_embeds"][:bs]
position_ids = self.cuda_graphs[max_bs]["position_ids"][:bs]
if ATTENTION == "flashinfer":
block_tables = self.cuda_graphs[max_bs]["block_tables"][: bs * max_bt]
else:
block_tables = self.cuda_graphs[max_bs]["block_tables"][:bs]
slots = self.cuda_graphs[max_bs]["slots"][:bs]
input_lengths_tensor = self.cuda_graphs[max_bs]["input_lengths"][:bs]
cache_lengths_tensor = self.cuda_graphs[max_bs]["cache_lengths"][:bs]

if ATTENTION == "flashinfer":
from text_generation_server.layers.attention.flashinfer import (
create_decode_state_cuda_graphs,
)

block_tables_ptr = torch.zeros(
bs + 1, dtype=torch.int32, device=self.device
)
last_page_len = torch.ones(bs, dtype=torch.int32, device=self.device)
state = create_decode_state_cuda_graphs(
device=input_ids.device,
block_tables=block_tables,
block_tables_ptr=block_tables_ptr,
last_page_len=last_page_len,
num_heads=self.num_heads,
num_kv_heads=self.num_kv_heads,
)
else:
state = None

graph = torch.cuda.CUDAGraph()
self.cuda_graphs[bs] = {
"input_ids": input_ids,
"input_embeds": input_embeds,
"position_ids": position_ids,
"kv_cache": self.kv_cache,
"block_tables": block_tables,
"slots": slots,
"input_lengths": input_lengths_tensor,
"cache_lengths": cache_lengths_tensor,
"state": state,
"graph": graph,
}

torch.cuda.synchronize()
# Run once outside to warmup
with self._forward_context(
block_tables=block_tables,
cu_seqlen_prefill=None,
input_lengths_tensor=input_lengths_tensor,
state=state,
cache_lengths_tensor=cache_lengths_tensor,
):
seqlen = Seqlen(
input_lengths=input_lengths_tensor,
cache_lengths=cache_lengths_tensor,
cu_seqlen_q=None,
max_q=1,
max_k=max_s,
)
self.model.forward(
input_ids=input_ids,
inputs_embeds=input_embeds,
position_ids=position_ids,
cu_seqlen_prefill=None,
kv_cache=self.kv_cache,
block_tables=block_tables,
slots=slots,
seqlen=seqlen,
max_s=max_s,
prefill_cache_indices=None,
lm_head_indices=None,
)
del seqlen

torch.cuda.synchronize()

with torch.cuda.graph(graph, pool=MEM_POOL):
seqlen = Seqlen(
input_lengths=input_lengths_tensor,
cache_lengths=cache_lengths_tensor,
cu_seqlen_q=None,
max_q=1,
max_k=max_s,
)
logits, speculative_logits = self.model.forward(
input_ids=input_ids,
inputs_embeds=input_embeds,
position_ids=position_ids,
cu_seqlen_prefill=None,
kv_cache=self.kv_cache,
block_tables=block_tables,
slots=slots,
seqlen=seqlen,
max_s=max_s,
prefill_cache_indices=None,
lm_head_indices=None,
)
self.cuda_graphs[bs]["logits"] = logits
self.cuda_graphs[bs]["speculative_logits"] = speculative_logits
torch.cuda.synchronize()

def get_vision_embeds(
self,
pixel_values: torch.Tensor,
@@ -901,6 +1182,7 @@ class VlmCausalLM(FlashCausalLM):
# Copy inputs to the static inputs of the cuda graph
# Static inputs are potentially padded
cuda_graph["input_ids"][: input_ids.shape[0]] = input_ids
cuda_graph["input_embeds"][: inputs_embeds.shape[0]] = inputs_embeds
cuda_graph["position_ids"][: position_ids.shape[0]] = position_ids
if ATTENTION == "flashinfer":
block_tables = block_tables_to_ragged(
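
Note on the cuda_graph_warmup and replay hunks above: a static input_embeds buffer is now captured alongside input_ids and position_ids, and at replay time the real (possibly smaller) batch is copied into the front of those padded buffers. A small sketch of that copy contract with illustrative shapes; no GPU or actual CUDA graph is needed to show it.

import torch

# Static buffers captured for the largest batch size; only the first bs_actual
# rows are overwritten at replay time, the padded tail stays untouched.
bs_max, hidden = 8, 16
static_input_ids = torch.zeros(bs_max, dtype=torch.int64)
static_input_embeds = torch.zeros(bs_max, hidden)

input_ids = torch.tensor([3, 5, 7])          # bs_actual = 3
inputs_embeds = torch.randn(3, hidden)

static_input_ids[: input_ids.shape[0]] = input_ids
static_input_embeds[: inputs_embeds.shape[0]] = inputs_embeds

assert torch.equal(static_input_ids[:3], input_ids)
assert torch.count_nonzero(static_input_ids[3:]) == 0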