Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-05-21 09:42:09 +00:00
Chunked Prefill VLM (#3188)
* add logic
* working
* add encoder cache free
* fixes
* fix idefics
* update pixel_values
* add improvements
* add improvements
* improve
* nit
* fix inputs_embeds
* nit
* optimizations
* add prometheus port
* rename vars
* rename vars
* nit
* disable chunking for qwen
* review comments
* remove port
* improve headdim
* remove kwargs and redundant args
* fix qwen2_5
* fix config image_token_id error
* fix test
* update paligemma
* fix paligemma text
* minor fix
* fix qwen test
* fix qwen test
Parent: 533eee50dc
Commit: 329f612e55
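Summary note: per the PR title, the hunks below wire chunked prefill support through the vision-language models. Each flash VLM's forward, which previously took input_ids and pixel_values and ran the vision tower inline, is split into three steps: get_vision_embeds (vision tower plus projector, run once per image), get_inputs_embeds (embed the text tokens and merge the vision embeddings at the image-token positions), and a forward that consumes inputs_embeds only. A condensed interface sketch follows, with a hypothetical class name and abbreviated bodies; the exact per-model signatures are in the hunks themselves.

# Interface sketch only (hypothetical class name); condensed from the hunks below.
from typing import List, Optional, Tuple
import torch


class VlmModelInterface:
    def get_vision_embeds(
        self,
        pixel_values: torch.FloatTensor,
        pixel_attention_mask: Optional[torch.FloatTensor] = None,
        image_sizes: Optional[torch.Tensor] = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
    ) -> torch.Tensor:
        # Run the vision tower / projector once and return a flat
        # (num_image_tokens, hidden_size) tensor.
        raise NotImplementedError

    def get_inputs_embeds(
        self,
        input_ids: torch.Tensor,
        vision_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # Embed the text tokens and scatter vision_embeds over the
        # positions holding the image token id.
        raise NotImplementedError

    def forward(
        self,
        inputs_embeds: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        # Language-model pass only: it now consumes inputs_embeds instead of
        # input_ids and no longer receives pixel_values.
        raise NotImplementedError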
@@ -128,9 +128,6 @@ try:
 from text_generation_server.models.custom_modeling.flash_neox_modeling import (
 FlashGPTNeoXForCausalLM,
 )
- from text_generation_server.models.pali_gemma import (
- PaliGemmaBatch,
- )
 from text_generation_server.models.custom_modeling.flash_pali_gemma_modeling import (
 PaliGemmaForConditionalGeneration,
 )

@@ -1196,6 +1193,7 @@ def get_model(
 default_dtype=torch.bfloat16,
 trust_remote_code=trust_remote_code,
 lora_adapter_ids=lora_adapter_ids,
+ support_chunking=False,
 )
 elif FLASH_TRANSFORMERS_BACKEND:
 from transformers import Gemma3ForConditionalGeneration as Gemma3Model

@@ -1208,6 +1206,7 @@ def get_model(
 speculator=speculator,
 dtype=torch.bfloat16,
 trust_remote_code=trust_remote_code,
+ support_chunking=False,
 )
 elif sharded:
 raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma3"))

@@ -1523,6 +1522,8 @@ def get_model(
 kv_cache_dtype=kv_cache_dtype,
 trust_remote_code=trust_remote_code,
 lora_adapter_ids=lora_adapter_ids,
+ # TODO: Fix bug in rust image_text_replacement implementation
+ support_chunking=False,
 )
 # TODO: Uncomment when transformers is refactored
 # elif FLASH_TRANSFORMERS_BACKEND:

@@ -1554,6 +1555,8 @@ def get_model(
 lora_adapter_ids=lora_adapter_ids,
 config_class=Qwen2_5_VLConfig,
 processor_class=Qwen2_5_VLProcessor,
+ # TODO: Fix bug in rust image_text_replacement implementation
+ support_chunking=False,
 )
 # TODO: Uncomment when transformers is refactored
 # elif FLASH_TRANSFORMERS_BACKEND:

@@ -1583,6 +1586,7 @@ def get_model(
 default_dtype=torch.bfloat16,
 trust_remote_code=trust_remote_code,
 lora_adapter_ids=lora_adapter_ids,
+ support_chunking=False,
 )
 # TODO: Uncomment when transformers is refactored and cross attn is added
 # elif FLASH_TRANSFORMERS_BACKEND:

@@ -1676,7 +1680,6 @@ def get_model(
 default_dtype=torch.bfloat16,
 trust_remote_code=trust_remote_code,
 lora_adapter_ids=lora_adapter_ids,
- batch_class=PaliGemmaBatch,
 )
 elif FLASH_TRANSFORMERS_BACKEND:
 from transformers import PaliGemmaForConditionalGeneration as PaliGemmaModel

@@ -1689,7 +1692,6 @@ def get_model(
 speculator=speculator,
 dtype=torch.bfloat16,
 trust_remote_code=trust_remote_code,
- batch_class=PaliGemmaBatch,
 )
 else:
 raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("PaliGemma"))
@@ -700,6 +700,7 @@ class Gemma3ForConditionalGeneration(nn.Module):
 self.pad_token_id = (
 config.pad_token_id if config.pad_token_id is not None else -1
 )
+ self.dtype = weights.dtype

 def get_attention_mask(
 self,

@@ -762,9 +763,42 @@ class Gemma3ForConditionalGeneration(nn.Module):
 else:
 return torch.where(full_attention_mask, 0, min_dtype).to(device)

- def forward(
+ def get_vision_embeds(
+ self,
+ pixel_values: torch.FloatTensor,
+ pixel_attention_mask: Optional[torch.FloatTensor] = None,
+ image_sizes: Optional[torch.Tensor] = None,
+ image_grid_thw: Optional[torch.LongTensor] = None,
+ ):
+ pixel_values = pixel_values.to(dtype=self.dtype)
+ image_outputs = self.vision_model(pixel_values)
+ vision_outputs = self.post_vision_model_layernorm(
+ image_outputs.last_hidden_state
+ )
+ image_features = self.multimodal_projector(vision_outputs)
+ image_features = image_features.view(-1, image_features.shape[-1])
+ return image_features
+
+ def get_inputs_embeds(
 self,
 input_ids: torch.Tensor,
+ vision_embeds: torch.Tensor = None,
+ ):
+ inputs_embeds = self.text_model.embed_tokens(input_ids)
+
+ if vision_embeds is not None:
+ # Replace the image token embeddings with the vision features
+ image_token_mask = (input_ids == self.config.image_token_index).to(
+ input_ids.device
+ )
+ inputs_embeds[image_token_mask] = vision_embeds.view(
+ -1, vision_embeds.shape[-1]
+ )
+ return inputs_embeds
+
+ def forward(
+ self,
+ inputs_embeds: torch.Tensor,
 position_ids: torch.Tensor,
 cu_seqlen_prefill: Optional[torch.Tensor],
 kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],

@@ -777,35 +811,12 @@ class Gemma3ForConditionalGeneration(nn.Module):
 pixel_values: torch.FloatTensor = None,
 # Unused here
 attention_mask: Optional[torch.BoolTensor] = None,
- pixel_attention_mask: Optional[torch.BoolTensor] = None,
- image_sizes: Optional[torch.Tensor] = None,
 adapter_data: Optional[torch.Tensor] = None,
- image_grid_thw: Optional[torch.LongTensor] = None,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
- inputs_embeds = self.text_model.embed_tokens(input_ids)
 if cu_seqlen_prefill is not None:
 max_s += 1
 position_ids += 1

- if pixel_values is not None:
- pixel_values = pixel_values.to(dtype=inputs_embeds.dtype)
- image_outputs = self.vision_model(pixel_values)
- vision_outputs = self.post_vision_model_layernorm(
- image_outputs.last_hidden_state
- )
- image_features = self.multimodal_projector(vision_outputs)
-
- image_token_mask = (input_ids == self.config.image_token_index).to(
- input_ids.device
- )
- inputs_embeds[image_token_mask] = image_features.view(
- -1, image_features.shape[-1]
- )
- attention_mask = self.get_attention_mask(
- input_ids,
- cu_seqlen_prefill,
- inputs_embeds.dtype,
- )
 # Use flash attention for text-only input
 # else:
 # if cu_seqlen_prefill is not None:
@@ -116,11 +116,10 @@ class MistralAttention(torch.nn.Module):
 )
 self.num_heads = config.num_attention_heads
 self.hidden_size = config.hidden_size
- if hasattr(config, "head_dim"):
+ if getattr(config, "head_dim", None) is not None:
 self.head_size = config.head_dim
 else:
 self.head_size = self.hidden_size // self.num_heads
-
 self.rotary_emb = PositionRotaryEmbedding.static(
 config=config,
 dim=self.head_size,
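Note on the check above (the same change appears again in FlashCausalLM further down): hasattr(config, "head_dim") is true even when a config defines head_dim = None, which would set head_size to None; getattr(config, "head_dim", None) is not None falls back to hidden_size // num_attention_heads in that case. A minimal illustration, using SimpleNamespace as a stand-in for the real config object:

from types import SimpleNamespace

config = SimpleNamespace(head_dim=None, hidden_size=4096, num_attention_heads=32)

# Old check: passes even though head_dim carries no usable value.
print(hasattr(config, "head_dim"))                    # True  -> head_size = None
# New check: falls back to hidden_size // num_attention_heads instead.
print(getattr(config, "head_dim", None) is not None)  # False -> head_size = 128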
@@ -62,10 +62,40 @@ class PaliGemmaForConditionalGeneration(nn.Module):
 self.pad_token_id = (
 config.pad_token_id if config.pad_token_id is not None else -1
 )
+ self.dtype = weights.dtype

+ def get_vision_embeds(
+ self,
+ pixel_values: torch.FloatTensor,
+ pixel_attention_mask: Optional[torch.FloatTensor] = None,
+ image_sizes: Optional[torch.Tensor] = None,
+ image_grid_thw: Optional[torch.LongTensor] = None,
+ ):
+ pixel_values = pixel_values.to(dtype=self.dtype)
+ image_outputs = self.vision_tower(pixel_values)
+ last_hidden_state = self.post_vision_tower_layernorm(
+ image_outputs.last_hidden_state
+ )
+ image_features = self.multi_modal_projector(last_hidden_state)
+ image_features = image_features.view(-1, image_features.shape[-1])
+ return image_features
+
+ def get_inputs_embeds(
+ self,
+ input_ids: torch.Tensor,
+ vision_embeds: torch.Tensor = None,
+ ):
+ inputs_embeds = self.text_model.embed_tokens(input_ids)
+
+ if vision_embeds is not None:
+ mask = input_ids == self.config.image_token_index
+ inputs_embeds[mask] = vision_embeds
+
+ return inputs_embeds
+
 def forward(
 self,
- input_ids: torch.Tensor,
+ inputs_embeds: torch.Tensor,
 position_ids: torch.Tensor,
 cu_seqlen_prefill: Optional[torch.Tensor],
 kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],

@@ -75,33 +105,15 @@ class PaliGemmaForConditionalGeneration(nn.Module):
 max_s: int,
 prefill_cache_indices: Optional[torch.Tensor] = None,
 lm_head_indices: Optional[torch.Tensor] = None,
- pixel_values: torch.FloatTensor = None,
 # Unused here
- pixel_attention_mask: Optional[torch.BoolTensor] = None,
+ attention_mask: Optional[torch.BoolTensor] = None,
- image_sizes: Optional[torch.Tensor] = None,
 adapter_data: Optional[torch.Tensor] = None,
- image_grid_thw: Optional[torch.LongTensor] = None,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
- inputs_embeds = self.text_model.embed_tokens(input_ids)
 # TODO This is odd but apparently pali gemma position ids start at 1.
 if cu_seqlen_prefill is not None:
 max_s += 1
 position_ids += 1

- if pixel_values is not None:
- pixel_values = pixel_values.to(dtype=inputs_embeds.dtype)
- image_outputs = self.vision_tower(pixel_values)
- last_hidden_state = self.post_vision_tower_layernorm(
- image_outputs.last_hidden_state
- )
- image_features = self.multi_modal_projector(last_hidden_state)
-
- # mask where image or padding tokens
- mask = input_ids == self.config.image_token_index
-
- # insert image features into input embeddings
- inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
-
 hidden_states = self.text_model.model(
 inputs_embeds=inputs_embeds,
 position_ids=position_ids,
@@ -733,9 +733,93 @@ class Idefics2ForConditionalGeneration(nn.Module):
 inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
 return inputs_embeds

- def forward(
+ def get_vision_embeds(
+ self,
+ pixel_values: torch.FloatTensor,
+ pixel_attention_mask: Optional[torch.FloatTensor] = None,
+ image_sizes: Optional[torch.Tensor] = None,
+ image_grid_thw: Optional[torch.LongTensor] = None,
+ ):
+ assert pixel_values is not None
+ batch_size, num_images, num_channels, height, width = pixel_values.shape
+ all_states = []
+ all_pixel_values = pixel_values
+ all_pixel_mask = pixel_attention_mask
+ for i in range(batch_size):
+ pixel_values = all_pixel_values.to(dtype=self.dtype) # fp16 compatibility
+ pixel_values = pixel_values[i : i + 1]
+ pixel_values = pixel_values.view(num_images, *pixel_values.shape[2:])
+
+ # Remove padding images - padding images are full 0.
+ nb_values_per_image = pixel_values.shape[1:].numel()
+ real_images_inds = (pixel_values == 0.0).sum(
+ dim=(-1, -2, -3)
+ ) != nb_values_per_image
+ pixel_values = pixel_values[real_images_inds].contiguous()
+
+ # Handle the vision attention mask
+ if pixel_attention_mask is None:
+ pixel_attention_mask = torch.ones(
+ size=(
+ pixel_values.size(0),
+ pixel_values.size(2),
+ pixel_values.size(3),
+ ),
+ dtype=torch.bool,
+ device=pixel_values.device,
+ )
+ else:
+ # Remove padding images from the mask/pP p
+ pixel_attention_mask = all_pixel_mask[i : i + 1]
+ pixel_attention_mask = pixel_attention_mask.view(
+ 1 * num_images, *pixel_attention_mask.shape[2:]
+ )
+ pixel_attention_mask = pixel_attention_mask[
+ real_images_inds
+ ].contiguous()
+
+ patch_size = self.config.vision_config.patch_size
+ patches_subgrid = pixel_attention_mask.unfold(
+ dimension=1, size=patch_size, step=patch_size
+ )
+ patches_subgrid = patches_subgrid.unfold(
+ dimension=2, size=patch_size, step=patch_size
+ )
+ patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()
+
+ # Get sequence from the vision encoder
+ image_hidden_states = self.vision_model(
+ pixel_values=pixel_values,
+ patch_attention_mask=patch_attention_mask,
+ )
+
+ # Modality projection & resampling
+ image_hidden_states = self.connector(
+ image_hidden_states,
+ attention_mask=patch_attention_mask.view(pixel_values.size(0), -1),
+ )
+ all_states.append(image_hidden_states)
+ image_hidden_states = torch.stack(all_states, dim=0)
+ return image_hidden_states.view(-1, image_hidden_states.shape[-1])
+
+ def get_inputs_embeds(
 self,
 input_ids: torch.Tensor,
+ vision_embeds: torch.Tensor = None,
+ ):
+ inputs_embeds = self.text_model.embed_tokens(input_ids)
+
+ if vision_embeds is not None:
+ # When we generate, we don't want to replace the potential image_token_id that we generated by images
+ # that simply don't exist
+ inputs_embeds = self._merge_input_ids_with_image_features(
+ input_ids, inputs_embeds, vision_embeds
+ )
+ return inputs_embeds
+
+ def forward(
+ self,
+ inputs_embeds: torch.Tensor,
 position_ids: torch.Tensor,
 cu_seqlen_prefill: Optional[torch.Tensor],
 kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],

@@ -745,82 +829,10 @@ class Idefics2ForConditionalGeneration(nn.Module):
 max_s: int,
 prefill_cache_indices: Optional[torch.Tensor],
 lm_head_indices: Optional[torch.Tensor] = None,
- pixel_values: torch.FloatTensor = None,
- pixel_attention_mask: Optional[torch.BoolTensor] = None,
 # Unused here
- image_sizes: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.BoolTensor] = None,
 adapter_data: Optional[torch.Tensor] = None,
- image_grid_thw: Optional[torch.LongTensor] = None,
 ):
- inputs_embeds = self.text_model.embed_tokens(input_ids)
- if pixel_values is not None:
- batch_size, num_images, num_channels, height, width = pixel_values.shape
- all_states = []
- all_pixel_values = pixel_values
- all_pixel_mask = pixel_attention_mask
- for i in range(batch_size):
- pixel_values = all_pixel_values.to(
- dtype=self.dtype
- ) # fp16 compatibility
- pixel_values = pixel_values[i : i + 1]
- pixel_values = pixel_values.view(num_images, *pixel_values.shape[2:])
-
- # Remove padding images - padding images are full 0.
- nb_values_per_image = pixel_values.shape[1:].numel()
- real_images_inds = (pixel_values == 0.0).sum(
- dim=(-1, -2, -3)
- ) != nb_values_per_image
- pixel_values = pixel_values[real_images_inds].contiguous()
-
- # Handle the vision attention mask
- if pixel_attention_mask is None:
- pixel_attention_mask = torch.ones(
- size=(
- pixel_values.size(0),
- pixel_values.size(2),
- pixel_values.size(3),
- ),
- dtype=torch.bool,
- device=pixel_values.device,
- )
- else:
- # Remove padding images from the mask/pP p
- pixel_attention_mask = all_pixel_mask[i : i + 1]
- pixel_attention_mask = pixel_attention_mask.view(
- 1 * num_images, *pixel_attention_mask.shape[2:]
- )
- pixel_attention_mask = pixel_attention_mask[
- real_images_inds
- ].contiguous()
-
- patch_size = self.config.vision_config.patch_size
- patches_subgrid = pixel_attention_mask.unfold(
- dimension=1, size=patch_size, step=patch_size
- )
- patches_subgrid = patches_subgrid.unfold(
- dimension=2, size=patch_size, step=patch_size
- )
- patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()
-
- # Get sequence from the vision encoder
- image_hidden_states = self.vision_model(
- pixel_values=pixel_values,
- patch_attention_mask=patch_attention_mask,
- )
-
- # Modality projection & resampling
- image_hidden_states = self.connector(
- image_hidden_states,
- attention_mask=patch_attention_mask.view(pixel_values.size(0), -1),
- )
- all_states.append(image_hidden_states)
- image_hidden_states = torch.stack(all_states, dim=0)
- # When we generate, we don't want to replace the potential image_token_id that we generated by images
- # that simply don't exist
- inputs_embeds = self._merge_input_ids_with_image_features(
- input_ids, inputs_embeds, image_hidden_states
- )
-
 hidden_states = self.text_model.model(
 inputs_embeds=inputs_embeds,
 position_ids=position_ids,
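The padding-image filter in get_vision_embeds above marks an image as padding exactly when all of its values are zero, i.e. when the count of zero entries equals the number of values per image. A small self-contained illustration of that indexing, with toy shapes:

import torch

# Two images of shape (3, 2, 2); the second is all zeros, i.e. padding.
pixel_values = torch.stack([torch.ones(3, 2, 2), torch.zeros(3, 2, 2)])

nb_values_per_image = pixel_values.shape[1:].numel()  # 3 * 2 * 2 = 12
real_images_inds = (pixel_values == 0.0).sum(dim=(-1, -2, -3)) != nb_values_per_image

print(real_images_inds)                      # tensor([ True, False])
print(pixel_values[real_images_inds].shape)  # torch.Size([1, 3, 2, 2])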
@@ -476,9 +476,92 @@ class Idefics3ForConditionalGeneration(nn.Module):
 inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
 return inputs_embeds

- def forward(
+ def get_vision_embeds(
+ self,
+ pixel_values: torch.FloatTensor,
+ pixel_attention_mask: Optional[torch.FloatTensor] = None,
+ image_sizes: Optional[torch.Tensor] = None,
+ image_grid_thw: Optional[torch.LongTensor] = None,
+ ):
+ batch_size, num_images, num_channels, height, width = pixel_values.shape
+ all_states = []
+ all_pixel_values = pixel_values
+ all_pixel_mask = pixel_attention_mask
+ for i in range(batch_size):
+ pixel_values = all_pixel_values.to(dtype=self.dtype) # fp16 compatibility
+ pixel_values = pixel_values[i : i + 1]
+ pixel_values = pixel_values.view(num_images, *pixel_values.shape[2:])
+
+ # Remove padding images - padding images are full 0.
+ nb_values_per_image = pixel_values.shape[1:].numel()
+ real_images_inds = (pixel_values == 0.0).sum(
+ dim=(-1, -2, -3)
+ ) != nb_values_per_image
+ pixel_values = pixel_values[real_images_inds].contiguous()
+ # Handle the vision attention mask
+ if pixel_attention_mask is None:
+ pixel_attention_mask = torch.ones(
+ size=(
+ pixel_values.size(0),
+ pixel_values.size(2),
+ pixel_values.size(3),
+ ),
+ dtype=torch.bool,
+ device=pixel_values.device,
+ )
+ else:
+ # Remove padding images from the mask/pP p
+ pixel_attention_mask = all_pixel_mask[i : i + 1]
+ pixel_attention_mask = pixel_attention_mask.view(
+ 1 * num_images, *pixel_attention_mask.shape[2:]
+ )
+ pixel_attention_mask = pixel_attention_mask[
+ real_images_inds
+ ].contiguous()
+
+ patch_size = self.config.vision_config.patch_size
+ patches_subgrid = pixel_attention_mask.unfold(
+ dimension=1, size=patch_size, step=patch_size
+ )
+ patches_subgrid = patches_subgrid.unfold(
+ dimension=2, size=patch_size, step=patch_size
+ )
+ patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()
+
+ # Get sequence from the vision encoder
+ image_hidden_states = self.vision_model(
+ pixel_values=pixel_values,
+ patch_attention_mask=patch_attention_mask,
+ )
+
+ # Modality projection & resampling
+ image_hidden_states = self.connector(
+ image_hidden_states,
+ )
+
+ all_states.append(image_hidden_states)
+ image_hidden_states = torch.stack(all_states, dim=0)
+
+ return image_hidden_states.view(-1, image_hidden_states.shape[-1])
+
+ def get_inputs_embeds(
 self,
 input_ids: torch.Tensor,
+ vision_embeds: torch.Tensor = None,
+ ):
+ inputs_embeds = self.text_model.embed_tokens(input_ids)
+
+ if vision_embeds is not None:
+ # When we generate, we don't want to replace the potential image_token_id that we generated by images
+ # that simply don't exist
+ inputs_embeds = self._merge_input_ids_with_image_features(
+ input_ids, inputs_embeds, vision_embeds
+ )
+ return inputs_embeds
+
+ def forward(
+ self,
+ inputs_embeds: torch.Tensor,
 position_ids: torch.Tensor,
 cu_seqlen_prefill: Optional[torch.Tensor],
 kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],

@@ -488,83 +571,11 @@ class Idefics3ForConditionalGeneration(nn.Module):
 max_s: int,
 prefill_cache_indices: Optional[torch.Tensor],
 lm_head_indices: Optional[torch.Tensor] = None,
- pixel_values: torch.FloatTensor = None,
- pixel_attention_mask: Optional[torch.BoolTensor] = None,
 # Unused here
- image_sizes: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.BoolTensor] = None,
 adapter_data: Optional[torch.Tensor] = None,
- image_grid_thw: Optional[torch.LongTensor] = None,
- video_grid_thw: Optional[torch.LongTensor] = None,
- cross_attention_states: Optional[torch.Tensor] = None,
 image_indices=None,
 ):
- inputs_embeds = self.text_model.embed_tokens(input_ids)
- if pixel_values is not None:
- batch_size, num_images, num_channels, height, width = pixel_values.shape
- all_states = []
- all_pixel_values = pixel_values
- all_pixel_mask = pixel_attention_mask
- for i in range(batch_size):
- pixel_values = all_pixel_values.to(
- dtype=self.dtype
- ) # fp16 compatibility
- pixel_values = pixel_values[i : i + 1]
- pixel_values = pixel_values.view(num_images, *pixel_values.shape[2:])
-
- # Remove padding images - padding images are full 0.
- nb_values_per_image = pixel_values.shape[1:].numel()
- real_images_inds = (pixel_values == 0.0).sum(
- dim=(-1, -2, -3)
- ) != nb_values_per_image
- pixel_values = pixel_values[real_images_inds].contiguous()
- # Handle the vision attention mask
- if pixel_attention_mask is None:
- pixel_attention_mask = torch.ones(
- size=(
- pixel_values.size(0),
- pixel_values.size(2),
- pixel_values.size(3),
- ),
- dtype=torch.bool,
- device=pixel_values.device,
- )
- else:
- # Remove padding images from the mask/pP p
- pixel_attention_mask = all_pixel_mask[i : i + 1]
- pixel_attention_mask = pixel_attention_mask.view(
- 1 * num_images, *pixel_attention_mask.shape[2:]
- )
- pixel_attention_mask = pixel_attention_mask[
- real_images_inds
- ].contiguous()
-
- patch_size = self.config.vision_config.patch_size
- patches_subgrid = pixel_attention_mask.unfold(
- dimension=1, size=patch_size, step=patch_size
- )
- patches_subgrid = patches_subgrid.unfold(
- dimension=2, size=patch_size, step=patch_size
- )
- patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()
-
- # Get sequence from the vision encoder
- image_hidden_states = self.vision_model(
- pixel_values=pixel_values,
- patch_attention_mask=patch_attention_mask,
- )
-
- # Modality projection & resampling
- image_hidden_states = self.connector(
- image_hidden_states,
- )
-
- all_states.append(image_hidden_states)
- image_hidden_states = torch.stack(all_states, dim=0)
-
- inputs_embeds = self._merge_input_ids_with_image_features(
- input_ids, inputs_embeds, image_hidden_states
- )
-
 hidden_states = self.text_model.model(
 inputs_embeds=inputs_embeds,
 position_ids=position_ids,
@@ -163,9 +163,114 @@ class LlavaNextForConditionalGeneration(nn.Module):
 )
 return inputs_embeds

- def forward(
+ def get_vision_embeds(
+ self,
+ pixel_values: torch.FloatTensor,
+ pixel_attention_mask: Optional[torch.FloatTensor] = None,
+ image_sizes: Optional[torch.Tensor] = None,
+ image_grid_thw: Optional[torch.LongTensor] = None,
+ ):
+ # num_special_image_tokens = (input_ids == self.config.image_token_index).sum()
+ # assert num_special_image_tokens == len(pixel_values), f"Received {num_special_image_tokens} for {len(pixel_values)} images, this is invalid"
+ # 1. Extract the input embeddings
+
+ # 2. Merge text and images
+ num_images, num_patches, channels, height, width = pixel_values.shape
+ pixel_values = pixel_values.view(
+ num_images * num_patches, channels, height, width
+ )
+ image_features = self.vision_tower(pixel_values)
+
+ # selected_image_feature = image_features.hidden_states[self.config.vision_feature_layer]
+ # Already done within the clip model
+ selected_image_feature = image_features.last_hidden_state
+
+ if self.config.vision_feature_select_strategy == "default":
+ selected_image_feature = selected_image_feature[:, 1:]
+ elif self.config.vision_feature_select_strategy == "full":
+ selected_image_feature = selected_image_feature
+ else:
+ raise RuntimeError(
+ f"Strategy `{self.config.vision_feature_select_strategy}` is not supported/valid."
+ )
+
+ image_features = self.multi_modal_projector(selected_image_feature)
+
+ # split up image_features for each of the individual images
+ # hence we get a list of image_features, each of shape (5, num_patches, hidden_size)
+ # if we assume each image has 5 image features (base image + 4 patches)
+ split_sizes = [num_patches] * num_images
+ image_features = torch.split(image_features, split_sizes, dim=0)
+
+ # NOTE we only support multimodal_patch_merge_type == "spatial_unpad"
+ height = width = (
+ self.config.vision_config.image_size // self.config.vision_config.patch_size
+ )
+
+ new_image_features = []
+ for image_idx, image_feature in enumerate(image_features):
+ if image_feature.shape[0] > 1:
+ base_image_feature = image_feature[0]
+ image_feature = image_feature[1:]
+
+ if height * width != base_image_feature.shape[0]:
+ raise ValueError(
+ "The number of patches is not consistent with the image size."
+ )
+
+ # Dimensions are intentionally swapped to be bug-compatible with
+ # upstream: https://github.com/LLaVA-VL/LLaVA-NeXT/issues/59
+ num_patch_width, num_patch_height = get_anyres_image_grid_shape(
+ image_sizes[image_idx],
+ self.config.image_grid_pinpoints,
+ self.config.vision_config.image_size,
+ )
+ image_feature = image_feature.view(
+ num_patch_height, num_patch_width, height, width, -1
+ )
+ image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
+ image_feature = image_feature.flatten(1, 2).flatten(2, 3)
+ image_feature = unpad_image(image_feature, image_sizes[image_idx])
+ image_feature = torch.cat(
+ (
+ image_feature,
+ self.image_newline[:, None, None].expand(
+ *image_feature.shape[:-1], 1
+ ),
+ ),
+ dim=-1,
+ )
+ image_feature = image_feature.flatten(1, 2).transpose(0, 1)
+ image_feature = torch.cat((base_image_feature, image_feature), dim=0)
+ else:
+ image_feature = image_feature[0]
+ image_feature = torch.cat(
+ (image_feature, self.image_newline[None]), dim=0
+ )
+ new_image_features.append(image_feature)
+ image_features = torch.stack(new_image_features, dim=0)
+ return image_features.view(-1, image_features.shape[-1])
+
+ def get_inputs_embeds(
 self,
 input_ids: torch.Tensor,
+ vision_embeds: torch.Tensor = None,
+ pixel_values: torch.FloatTensor = None,
+ image_sizes: Optional[torch.LongTensor] = None,
+ ):
+ inputs_embeds = self.text_model.embed_tokens(input_ids)
+
+ if vision_embeds is not None:
+ # When we generate, we don't want to replace the potential image_token_id that we generated by images
+ # that simply don't exist
+ inputs_embeds = self._merge_input_ids_with_image_features(
+ input_ids, inputs_embeds, vision_embeds
+ )
+ return inputs_embeds
+
+ def forward(
+ self,
+ inputs_embeds: torch.Tensor,
 position_ids: torch.Tensor,
 cu_seqlen_prefill: Optional[torch.Tensor],
 kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],

@@ -175,102 +280,10 @@ class LlavaNextForConditionalGeneration(nn.Module):
 max_s: int,
 prefill_cache_indices: Optional[torch.Tensor],
 lm_head_indices: Optional[torch.Tensor] = None,
- pixel_values: torch.FloatTensor = None,
 # Unused for this model
- pixel_attention_mask=None,
+ attention_mask: Optional[torch.BoolTensor] = None,
- image_sizes: Optional[torch.LongTensor] = None,
 adapter_data: Optional[torch.Tensor] = None,
- image_grid_thw: Optional[torch.LongTensor] = None,
 ):
- inputs_embeds = self.text_model.embed_tokens(input_ids)
- if pixel_values is not None and len(pixel_values) > 0:
- # num_special_image_tokens = (input_ids == self.config.image_token_index).sum()
- # assert num_special_image_tokens == len(pixel_values), f"Received {num_special_image_tokens} for {len(pixel_values)} images, this is invalid"
- # 1. Extract the input embeddings
-
- # 2. Merge text and images
- num_images, num_patches, channels, height, width = pixel_values.shape
- pixel_values = pixel_values.view(
- num_images * num_patches, channels, height, width
- )
- image_features = self.vision_tower(pixel_values)
-
- # selected_image_feature = image_features.hidden_states[self.config.vision_feature_layer]
- # Already done within the clip model
- selected_image_feature = image_features.last_hidden_state
-
- if self.config.vision_feature_select_strategy == "default":
- selected_image_feature = selected_image_feature[:, 1:]
- elif self.config.vision_feature_select_strategy == "full":
- selected_image_feature = selected_image_feature
- else:
- raise RuntimeError(
- f"Strategy `{self.config.vision_feature_select_strategy}` is not supported/valid."
- )
-
- image_features = self.multi_modal_projector(selected_image_feature)
-
- # split up image_features for each of the individual images
- # hence we get a list of image_features, each of shape (5, num_patches, hidden_size)
- # if we assume each image has 5 image features (base image + 4 patches)
- split_sizes = [num_patches] * num_images
- image_features = torch.split(image_features, split_sizes, dim=0)
-
- # NOTE we only support multimodal_patch_merge_type == "spatial_unpad"
- height = width = (
- self.config.vision_config.image_size
- // self.config.vision_config.patch_size
- )
-
- new_image_features = []
- for image_idx, image_feature in enumerate(image_features):
- if image_feature.shape[0] > 1:
- base_image_feature = image_feature[0]
- image_feature = image_feature[1:]
-
- if height * width != base_image_feature.shape[0]:
- raise ValueError(
- "The number of patches is not consistent with the image size."
- )
-
- # Dimensions are intentionally swapped to be bug-compatible with
- # upstream: https://github.com/LLaVA-VL/LLaVA-NeXT/issues/59
- num_patch_width, num_patch_height = get_anyres_image_grid_shape(
- image_sizes[image_idx],
- self.config.image_grid_pinpoints,
- self.config.vision_config.image_size,
- )
- image_feature = image_feature.view(
- num_patch_height, num_patch_width, height, width, -1
- )
- image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
- image_feature = image_feature.flatten(1, 2).flatten(2, 3)
- image_feature = unpad_image(image_feature, image_sizes[image_idx])
- image_feature = torch.cat(
- (
- image_feature,
- self.image_newline[:, None, None].expand(
- *image_feature.shape[:-1], 1
- ),
- ),
- dim=-1,
- )
- image_feature = image_feature.flatten(1, 2).transpose(0, 1)
- image_feature = torch.cat(
- (base_image_feature, image_feature), dim=0
- )
- else:
- image_feature = image_feature[0]
- image_feature = torch.cat(
- (image_feature, self.image_newline[None]), dim=0
- )
- new_image_features.append(image_feature)
- image_features = torch.stack(new_image_features, dim=0)
-
- inputs_embeds = self._merge_input_ids_with_image_features(
- input_ids, inputs_embeds, image_features
- )
-
 hidden_states = self.text_model.model(
 inputs_embeds=inputs_embeds,
 position_ids=position_ids,
@@ -922,9 +922,32 @@ class Qwen2_5VLForConditionalGeneration(nn.Module):
 )
 return position_ids

- def forward(
+ def get_vision_embeds(
+ self,
+ pixel_values: torch.FloatTensor,
+ pixel_attention_mask: Optional[torch.FloatTensor] = None,
+ image_sizes: Optional[torch.Tensor] = None,
+ image_grid_thw: Optional[torch.LongTensor] = None,
+ ):
+ image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw).squeeze(0)
+ return image_embeds
+
+ def get_inputs_embeds(
 self,
 input_ids: torch.Tensor,
+ vision_embeds: torch.Tensor = None,
+ ):
+ inputs_embeds = self.embed_tokens(input_ids)
+
+ # apply the visual model to the pixel values if they are provided
+ if vision_embeds is not None:
+ inputs_embeds[input_ids == self.image_token_id] = vision_embeds
+
+ return inputs_embeds
+
+ def forward(
+ self,
+ inputs_embeds: torch.Tensor,
 position_ids: torch.Tensor,
 cu_seqlen_prefill: Optional[torch.Tensor],
 kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],

@@ -934,26 +957,11 @@ class Qwen2_5VLForConditionalGeneration(nn.Module):
 max_s: int,
 prefill_cache_indices: Optional[torch.Tensor],
 lm_head_indices: Optional[torch.Tensor],
- pixel_values: torch.FloatTensor = None,
- image_grid_thw: Optional[torch.LongTensor] = None,
 # Unused in this model
- video_grid_thw: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
- pixel_attention_mask=None,
- image_sizes: Optional[torch.LongTensor] = None,
 adapter_data: Optional[torch.Tensor] = None,
- cross_attention_states: Optional[torch.Tensor] = None,
 image_indices=None,
 ):
- inputs_embeds = self.embed_tokens(input_ids)
-
- # apply the visual model to the pixel values if they are provided
- if pixel_values is not None and len(pixel_values) > 0:
- if pixel_values is not None:
- image_embeds = self.visual(
- pixel_values, grid_thw=image_grid_thw
- ).squeeze(0)
- inputs_embeds[input_ids == self.image_token_id] = image_embeds
-
 hidden_states = self.text_model(
 inputs_embeds=inputs_embeds,
 position_ids=position_ids,
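In both Qwen hunks, get_inputs_embeds writes the vision embeddings into the token embeddings with a boolean-mask assignment, so the prompt must contain exactly one image-token position per row returned by get_vision_embeds. A minimal illustration with made-up sizes and a stand-in token id:

import torch

IMAGE_TOKEN_ID = 9999  # stand-in value for the example
hidden_size = 8

input_ids = torch.tensor([1, IMAGE_TOKEN_ID, IMAGE_TOKEN_ID, 2])
inputs_embeds = torch.zeros(input_ids.shape[0], hidden_size)
vision_embeds = torch.randn(2, hidden_size)  # one row per image-token position

inputs_embeds[input_ids == IMAGE_TOKEN_ID] = vision_embeds
# A count mismatch (e.g. vision_embeds of shape (3, hidden_size)) raises a RuntimeError.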
@@ -500,9 +500,32 @@ class Qwen2VLForConditionalGeneration(nn.Module):
 )
 return position_ids

- def forward(
+ def get_vision_embeds(
+ self,
+ pixel_values: torch.FloatTensor,
+ pixel_attention_mask: Optional[torch.FloatTensor] = None,
+ image_sizes: Optional[torch.Tensor] = None,
+ image_grid_thw: Optional[torch.LongTensor] = None,
+ ):
+ image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw).squeeze(0)
+ return image_embeds
+
+ def get_inputs_embeds(
 self,
 input_ids: torch.Tensor,
+ vision_embeds: torch.Tensor = None,
+ ):
+ inputs_embeds = self.embed_tokens(input_ids)
+
+ # apply the visual model to the pixel values if they are provided
+ if vision_embeds is not None:
+ inputs_embeds[input_ids == self.image_token_id] = vision_embeds
+
+ return inputs_embeds
+
+ def forward(
+ self,
+ inputs_embeds: torch.Tensor,
 position_ids: torch.Tensor,
 cu_seqlen_prefill: Optional[torch.Tensor],
 kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],

@@ -512,25 +535,10 @@ class Qwen2VLForConditionalGeneration(nn.Module):
 max_s: int,
 prefill_cache_indices: Optional[torch.Tensor],
 lm_head_indices: Optional[torch.Tensor],
- pixel_values: torch.FloatTensor = None,
- image_grid_thw: Optional[torch.LongTensor] = None,
- video_grid_thw: Optional[torch.LongTensor] = None,
- pixel_attention_mask=None,
- image_sizes: Optional[torch.LongTensor] = None,
 adapter_data: Optional[torch.Tensor] = None,
- cross_attention_states: Optional[torch.Tensor] = None,
 image_indices=None,
+ attention_mask=None,
 ):
- inputs_embeds = self.embed_tokens(input_ids)
-
- # apply the visual model to the pixel values if they are provided
- if pixel_values is not None and len(pixel_values) > 0:
- if pixel_values is not None:
- image_embeds = self.visual(
- pixel_values, grid_thw=image_grid_thw
- ).squeeze(0)
- inputs_embeds[input_ids == self.image_token_id] = image_embeds
-
 hidden_states = self.text_model(
 inputs_embeds=inputs_embeds,
 position_ids=position_ids,
@@ -1298,7 +1298,7 @@ class FlashCausalLM(Model):
 if head_size is None:
 # Some models use GQA and different sizes for o_proj
 # and q_proj, that allows for that.
- if hasattr(config, "head_dim"):
+ if getattr(config, "head_dim", None) is not None:
 self.head_size = config.head_dim
 else:
 self.head_size = config.hidden_size // config.num_attention_heads

@@ -1896,6 +1896,9 @@ class FlashCausalLM(Model):
 if prefill:
 batch.prepare_for_prefill()

+ if hasattr(self, "set_inputs_embeds") and callable(self.set_inputs_embeds):
+ self.set_inputs_embeds(batch)
+
 prefill_logprobs = batch.prefill_next_token_indices is not None

 # Update adapter indices for speculative tokens (if present)
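The hasattr/callable guard above is a duck-typed hook: FlashCausalLM itself does not appear to define set_inputs_embeds, so text-only models skip the call, while VLM subclasses can populate batch.inputs_embeds before the forward pass (MllamaCausalLM below deliberately resets it to None). A hypothetical sketch of such a hook — the class is illustrative, and the batch/model attribute names are assumed to mirror the fields seen elsewhere in this diff, not copied from the repository:

# Hypothetical subclass sketch: the real VlmCausalLM.set_inputs_embeds is not
# shown in this diff; only the hook call site above and the model-side
# get_vision_embeds / get_inputs_embeds methods are.
class SketchVlmCausalLM:
    def __init__(self, model):
        self.model = model

    def set_inputs_embeds(self, batch):
        vision_embeds = None
        if batch.pixel_values is not None:
            # Encode images once; the result could be cached across prefill chunks.
            vision_embeds = self.model.get_vision_embeds(
                pixel_values=batch.pixel_values,
                pixel_attention_mask=batch.pixel_attention_mask,
                image_sizes=batch.image_sizes,
                image_grid_thw=batch.image_grid_thw,
            )
        batch.inputs_embeds = self.model.get_inputs_embeds(
            batch.input_ids, vision_embeds=vision_embeds
        )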
@@ -29,10 +29,13 @@ class MllamaCausalLMBatch(VlmCausalLMBatch):
 aspect_ratio_mask: Optional[torch.Tensor] = None
 cross_attention_states: Optional[torch.Tensor] = None

+ def prepare_for_prefill(self):
+ super(VlmCausalLMBatch, self).prepare_for_prefill()
+
 @classmethod
 @tracer.start_as_current_span("concatenate")
 def concatenate(cls, batches):
- batch = super().concatenate(batches)
+ batch = super(VlmCausalLMBatch, cls).concatenate(batches)
 batch.pixel_values = None
 batch.pixel_attention_mask = None

@@ -196,6 +199,13 @@ class MllamaCausalLMBatch(VlmCausalLMBatch):


 class MllamaCausalLM(VlmCausalLM):
+ def set_inputs_embeds(self, batch):
+ # Set the input embeddings to None, as we are using the input_ids for the model
+ batch.inputs_embeds = None
+
+ def cuda_graph_warmup(self, bs: int, max_s: int, max_bt: int):
+ super(VlmCausalLM, self).cuda_graph_warmup(bs, max_s, max_bt)
+
 def forward(
 self,
 batch: MllamaCausalLMBatch,
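Both Mllama overrides above use the two-argument form of super() to start method lookup after VlmCausalLMBatch / VlmCausalLM in the MRO, deliberately skipping the VLM-specific prepare_for_prefill, concatenate and cuda_graph_warmup behaviour and falling through to the FlashCausalLM-level implementations. A minimal illustration of that lookup rule, with placeholder class names:

class FlashBase:
    def prepare(self):
        return "flash base"


class VlmLayer(FlashBase):
    def prepare(self):
        return "vlm layer"


class Mllama(VlmLayer):
    def prepare(self):
        # super(VlmLayer, self) starts the lookup *after* VlmLayer in the MRO,
        # so FlashBase.prepare runs and VlmLayer.prepare is skipped.
        return super(VlmLayer, self).prepare()


print(Mllama().prepare())  # "flash base"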
@@ -1,71 +0,0 @@
- from io import BytesIO
- from PIL import Image
- import torch
- import torch.distributed
- from opentelemetry import trace
- from typing import Iterable
- from text_generation_server.models.vlm_causal_lm import (
- VlmCausalLMBatch,
- image_text_replacement,
- )
-
- from text_generation_server.pb.generate_pb2 import Request
-
- tracer = trace.get_tracer(__name__)
-
-
- class PaliGemmaBatch(VlmCausalLMBatch):
- @classmethod
- def batch_tokenized_inputs(
- cls, requests: Iterable[Request], tokenizer, processor, config
- ):
- batch_inputs = []
- image_inputs = []
- max_truncation = 0
- for r in requests:
- full_text = ""
- image_id = 0
- for chunk in r.input_chunks.chunks:
- chunk_type = chunk.WhichOneof("chunk")
- if chunk_type == "text":
- full_text += "<bos>" + chunk.text + "\n"
- elif chunk_type == "image":
- image = Image.open(BytesIO(chunk.image.data))
- # TODO do_convert_RGB should be on by default ?
- image = image.convert("RGB")
- image_input = processor.image_processor(image, return_tensors="pt")
- full_text += image_text_replacement(
- processor, image_input, config, image_id
- )
- image_inputs.append(image_input)
- else:
- raise RuntimeError(f"Invalid chunk type {chunk_type}")
-
- batch_inputs.append(full_text)
- max_truncation = max(max_truncation, r.truncate)
-
- batch_tokenized_inputs = tokenizer(
- batch_inputs,
- truncation=True,
- max_length=max_truncation,
- add_special_tokens=False,
- )["input_ids"]
- if image_inputs:
- image_input = image_inputs[0]
- new_image_inputs = {
- "pixel_values": torch.cat(
- [img["pixel_values"] for img in image_inputs], dim=0
- ),
- }
- if "pixel_attention_mask" in image_input:
- new_image_inputs["pixel_attention_mask"] = torch.cat(
- [img["pixel_attention_mask"] for img in image_inputs], dim=0
- )
- if "image_sizes" in image_input:
- new_image_inputs["image_sizes"] = torch.cat(
- [img["image_sizes"] for img in image_inputs], dim=0
- )
- image_inputs = new_image_inputs
- else:
- image_inputs = None
- return batch_tokenized_inputs, image_inputs
@ -163,6 +163,7 @@ class TransformersFlashVlmCausalLM(VlmCausalLM):
|
|||||||
processor_kwargs=None,
|
processor_kwargs=None,
|
||||||
kv_cache_dtype: Optional[torch.dtype] = None,
|
kv_cache_dtype: Optional[torch.dtype] = None,
|
||||||
batch_class=VlmCausalLMBatch,
|
batch_class=VlmCausalLMBatch,
|
||||||
|
support_chunking: bool = True,
|
||||||
):
|
):
|
||||||
self.batch_class = batch_class
|
self.batch_class = batch_class
|
||||||
self.quantize = quantize
|
self.quantize = quantize
|
||||||
@ -304,6 +305,7 @@ class TransformersFlashVlmCausalLM(VlmCausalLM):
|
|||||||
device=device,
|
device=device,
|
||||||
rank=rank,
|
rank=rank,
|
||||||
world_size=world_size,
|
world_size=world_size,
|
||||||
|
support_chunking=support_chunking,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Monkey patch of `self.model.forward` to match `FlashCausalLM`. It avoids duplicating a lot of code
|
# Monkey patch of `self.model.forward` to match `FlashCausalLM`. It avoids duplicating a lot of code
|
||||||
@ -338,6 +340,7 @@ class TransformersFlashVlmCausalLM(VlmCausalLM):
|
|||||||
trust_remote_code: bool = False,
|
trust_remote_code: bool = False,
|
||||||
batch_class: Optional[type] = VlmCausalLMBatch,
|
batch_class: Optional[type] = VlmCausalLMBatch,
|
||||||
processor_kwargs: Optional[dict] = None,
|
processor_kwargs: Optional[dict] = None,
|
||||||
|
support_chunking: bool = True,
|
||||||
):
|
):
|
||||||
return cls(
|
return cls(
|
||||||
model_id=model_id,
|
model_id=model_id,
|
||||||
@ -349,6 +352,7 @@ class TransformersFlashVlmCausalLM(VlmCausalLM):
|
|||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
batch_class=batch_class,
|
batch_class=batch_class,
|
||||||
processor_kwargs=processor_kwargs,
|
processor_kwargs=processor_kwargs,
|
||||||
|
support_chunking=support_chunking,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _model_forward(
|
def _model_forward(
|
||||||
@ -368,6 +372,7 @@ class TransformersFlashVlmCausalLM(VlmCausalLM):
|
|||||||
image_grid_thw: Optional[torch.LongTensor] = None,
|
image_grid_thw: Optional[torch.LongTensor] = None,
|
||||||
pixel_attention_mask=None,
|
pixel_attention_mask=None,
|
||||||
image_sizes: Optional[torch.LongTensor] = None,
|
image_sizes: Optional[torch.LongTensor] = None,
|
||||||
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
):
|
):
|
||||||
# A value of `None` (i.e. no logit slicing) translates to `0` in Transformers
|
# A value of `None` (i.e. no logit slicing) translates to `0` in Transformers
|
||||||
logits_to_keep = lm_head_indices if lm_head_indices is not None else 0
|
logits_to_keep = lm_head_indices if lm_head_indices is not None else 0
|
||||||
@ -377,9 +382,12 @@ class TransformersFlashVlmCausalLM(VlmCausalLM):
|
|||||||
position_ids=position_ids,
|
position_ids=position_ids,
|
||||||
cu_seqlen_prefill=cu_seqlen_prefill,
|
cu_seqlen_prefill=cu_seqlen_prefill,
|
||||||
)
|
)
|
||||||
|
inputs["input_ids"] = None
|
||||||
|
|
||||||
# This is equivalent to `self.model.forward`, see the monkey patch in __init__
|
# This is equivalent to `self.model.forward`, see the monkey patch in __init__
|
||||||
logits = self.model.original_forward(
|
logits = self.model.original_forward(
|
||||||
input_ids=inputs["input_ids"],
|
input_ids=inputs["input_ids"],
|
||||||
|
inputs_embeds=inputs_embeds.unsqueeze(0),
|
||||||
position_ids=inputs["position_ids"],
|
position_ids=inputs["position_ids"],
|
||||||
past_key_values=None, # we use self.kv_cache instead of transformers cache object
|
past_key_values=None, # we use self.kv_cache instead of transformers cache object
|
||||||
use_cache=False, # we use self.kv_cache instead of transformers cache object
|
use_cache=False, # we use self.kv_cache instead of transformers cache object
|
||||||
@@ -568,3 +576,48 @@ class TransformersLlama4VlmCausalLM(TransformersFlashVlmCausalLM):
         inputs["cache_position"] = position_ids
         inputs["attention_mask"] = torch.zeros((1, 1, 1, 1), device=input_ids.device)
         return inputs
+
+    def get_vision_embeds(
+        self,
+        pixel_values: torch.FloatTensor,
+        pixel_attention_mask: Optional[torch.FloatTensor] = None,
+        image_sizes: Optional[torch.Tensor] = None,
+        image_grid_thw: Optional[torch.LongTensor] = None,
+    ):
+        image_features = self.model.get_image_features(
+            pixel_values=pixel_values,
+            vision_feature_layer=self.model.config.vision_config.vision_feature_layer,
+            vision_feature_select_strategy=self.model.config.vision_config.vision_feature_select_strategy,
+            image_sizes=image_sizes,
+        )
+
+        vision_flat = image_features.view(-1, image_features.size(-1))
+        projected_vision_flat = self.model.multi_modal_projector(vision_flat)
+        return projected_vision_flat
+
+    def get_inputs_embeds(self, input_ids, vision_embeds=None):
+        inputs_embeds = self.model.get_input_embeddings()(input_ids)
+
+        if vision_embeds is not None:
+            original_inputs_embeds_shape = inputs_embeds.shape
+            special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(
+                -1
+            )
+            final_mask = special_image_mask.to(inputs_embeds.device)
+            inputs_embeds = inputs_embeds.view(-1, inputs_embeds.size(-1))
+
+            final_mask_1d = final_mask[..., 0].reshape(-1)
+            num_tokens_to_fill = final_mask_1d.sum()
+
+            if num_tokens_to_fill != vision_embeds.size(0):
+                raise ValueError(
+                    f"Mismatch: final_mask wants {num_tokens_to_fill} embeddings, "
+                    f"but multi_modal_projector returned {vision_embeds.size(0)}"
+                )
+
+            expanded_mask = final_mask_1d.unsqueeze(-1).expand(
+                -1, inputs_embeds.size(-1)
+            )
+            inputs_embeds = inputs_embeds.masked_scatter(expanded_mask, vision_embeds)
+            inputs_embeds = inputs_embeds.view(original_inputs_embeds_shape)
+        return inputs_embeds
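Note: the `get_inputs_embeds` addition above splices projected vision features into the text embedding stream with `masked_scatter`. A minimal standalone sketch of that pattern, with a made-up token id and sizes chosen only for illustration:

import torch

hidden_size = 8
IMAGE_TOKEN_ID = 32000  # hypothetical placeholder id, an assumption for this sketch

input_ids = torch.tensor([1, IMAGE_TOKEN_ID, IMAGE_TOKEN_ID, 2, 3])
text_embeds = torch.randn(input_ids.shape[0], hidden_size)
vision_embeds = torch.randn(2, hidden_size)  # one row per image placeholder token

# Boolean mask marking the placeholder positions, expanded to the hidden dim.
mask = (input_ids == IMAGE_TOKEN_ID).unsqueeze(-1).expand(-1, hidden_size)

# masked_scatter fills the masked positions, in order, from vision_embeds.
merged = text_embeds.masked_scatter(mask, vision_embeds)
assert torch.equal(merged[1], vision_embeds[0]) and torch.equal(merged[2], vision_embeds[1])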
@@ -1,3 +1,4 @@
+from dataclasses import dataclass
 import torch
 from PIL import Image
 from io import BytesIO
@@ -12,7 +13,7 @@ from text_generation_server.models.flash_causal_lm import (
     FlashCausalLMBatch,
     FlashCausalLM,
 )
-from text_generation_server.models.globals import PREFIX_CACHING, ATTENTION
+from text_generation_server.models.globals import PREFIX_CACHING, ATTENTION, MEM_POOL
 from loguru import logger
 from text_generation_server.utils.log import log_master
 from transformers import AutoProcessor
@@ -109,17 +110,17 @@ def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
     return height // patch_size, width // patch_size


-def image_text_replacement(processor, image_input, config, image_id: int) -> str:
+def image_text_replacement(processor, image_input, config) -> str:
     if config.model_type == "idefics2":
         image_seq_len = 64
         image_str = f"{IDEFICS2_FAKE_TOKEN}{IDEFICS2_IMAGE_TOKEN * image_seq_len}{IDEFICS2_FAKE_TOKEN}"
         if processor.image_processor.do_image_splitting:
            image_str *= 5
-        return image_str
+        return image_str, IDEFICS2_FAKE_TOKEN
     if config.model_type == "idefics3":
         # TODO: implement this in a more general way
-        n_rows = image_input["rows"][0][image_id]
-        n_cols = image_input["cols"][0][image_id]
+        n_rows = image_input["rows"][0][0]
+        n_cols = image_input["cols"][0][0]
         image_seq_len = int(
             ((config.vision_config.image_size // config.vision_config.patch_size) ** 2)
             / (config.scale_factor**2)
@@ -132,41 +133,41 @@ def image_text_replacement(processor, image_input, config, image_id: int) -> str
             image_token=IDEFICS3_IMAGE_TOKEN,
             global_img_token=IDEFICS3_GLOBAL_IMG_TOKEN,
         )
-        return image_str
+        return image_str, IDEFICS3_FAKE_IMAGE_TOKEN
     elif config.model_type == "llava_next":
-        height, width = image_input["image_sizes"][image_id]
+        height, width = image_input["image_sizes"][0]
         num_features = get_number_of_features(height, width, config)

         log_master(
             logger.info,
             f"Found {num_features} features in image of resolution {height}x{width}",
         )
-        return "<image>" * num_features
+        return "<image>" * num_features, "<image>"

     elif config.model_type == "paligemma":
-        return "<image>" * config.text_config.num_image_tokens
+        return "<image>" * config.text_config.num_image_tokens, "<image>"
     elif config.model_type == "qwen2_vl":
-        grid_t, grid_h, grid_w = image_input["image_grid_thw"][image_id]
+        grid_t, grid_h, grid_w = image_input["image_grid_thw"][0]
         num_pads = grid_t * grid_h * grid_w // 4
         padding = "<|image_pad|>" * num_pads
-        return f"<|vision_start|>{padding}<|vision_end|>"
+        return f"<|vision_start|>{padding}<|vision_end|>", "<|vision_start|>"
     elif config.model_type == "qwen2_5_vl":
-        grid_t, grid_h, grid_w = image_input["image_grid_thw"][image_id]
+        grid_t, grid_h, grid_w = image_input["image_grid_thw"][0]
         num_pads = grid_t * grid_h * grid_w // 4
         padding = "<|image_pad|>" * num_pads
-        return f"<|vision_start|>{padding}<|vision_end|>"
+        return f"<|vision_start|>{padding}<|vision_end|>", "<|vision_start|>"
     elif config.model_type == "gemma3":
         # TODO: get correct number of features via reviewing the Gemma3 architecture
         # and calculating the number of image tokens
         num_pads = 256
         padding = "<image_soft_token>" * num_pads
-        return f"\n\n<start_of_image>{padding}<end_of_image>\n\n"
+        return f"\n\n<start_of_image>{padding}<end_of_image>\n\n", "<start_of_image>"
     elif config.model_type == "llama4":
         patch_size = config.vision_config.patch_size
         pixel_shuffle_ratio = config.vision_config.pixel_shuffle_ratio
         downsample_ratio = int(round(1.0 / (pixel_shuffle_ratio**2)))
-        aspect_ratios = image_input["aspect_ratios"][image_id]
-        image_height, image_width = image_input["pixel_values"][image_id].shape[-2:]
+        aspect_ratios = image_input["aspect_ratios"][0]
+        image_height, image_width = image_input["pixel_values"][0].shape[-2:]

         num_patches_per_chunk = int(
             (image_height // patch_size)
@@ -177,7 +178,7 @@ def image_text_replacement(processor, image_input, config, image_id: int) -> str
             aspect_ratios, num_patches_per_chunk
         )

-        return tokens_for_this_image
+        return tokens_for_this_image, "<|image_start|>"
     else:
         raise RuntimeError(f"Unknown config {config.model_type} for multimodal")

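Note: each branch now returns both the replacement string and the literal token that opens it; the caller later uses that start token to locate the image span in the tokenized prompt. A rough illustration of the qwen2-vl arithmetic and of unpacking the pair (grid values are made up; the // 4 matches qwen2-vl's 2x2 spatial patch merge):

grid_t, grid_h, grid_w = 1, 32, 32          # hypothetical image grid
num_pads = grid_t * grid_h * grid_w // 4    # 256 placeholder tokens for this image
replacement = f"<|vision_start|>{'<|image_pad|>' * num_pads}<|vision_end|>"
start_token = "<|vision_start|>"
# The position of `start_token` in the tokenized prompt marks where the image block begins.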
@@ -190,6 +191,27 @@ def image_text_replacement_fixup(config, text: str) -> str:
     return text


+def preprocess_text(config, text: str) -> str:
+    if config.model_type == "paligemma":
+        return "<bos>" + text + "\n"
+    return text
+
+
+def preprocess_image(config, img):
+    model_type = config.model_type
+
+    if model_type in {"qwen2_vl", "qwen2_5_vl"} and img.width <= 20:
+        img = img.resize((img.width * 2, img.height * 2))
+    if model_type == "paligemma":
+        img = img.convert("RGB")
+
+    if model_type not in {"llava_next", "gemma3", "llama4"}:
+        # TODO: check if this is needed
+        img = [img]
+
+    return img
+
+
 def get_unpadded_features(
     original_height: int,
     original_width: int,
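Note: these two helpers centralize per-model quirks that were previously inlined in `batch_tokenized_inputs`. A quick illustration of what `preprocess_image` from the patch does for a tiny warmup-sized image (sizes and config stand-in are invented for this sketch):

from PIL import Image

class FakeConfig:                # stand-in config, only for this sketch
    model_type = "qwen2_5_vl"

img = Image.new("RGB", (20, 20))
out = preprocess_image(FakeConfig(), img)
# qwen2-vl/qwen2.5-vl reject very small images, so the 20x20 warmup image is upscaled to 40x40
# and, for this model type, wrapped in a list for the image processor.
print(out[0].size)  # (40, 40)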
@@ -244,105 +266,263 @@ def get_number_of_features(height: int, width: int, config) -> int:
     return unpadded_features + newline_features + base_features


+def scatter_image_embeds(
+    embeds: torch.Tensor, is_embed: Optional[torch.Tensor]
+) -> torch.Tensor:
+    if is_embed is None:
+        return embeds
+
+    placeholders = embeds.new_full(
+        (is_embed.shape[0], embeds.shape[-1]),
+        fill_value=torch.nan,
+    )
+    placeholders[is_embed] = embeds
+    return placeholders
+
+
+def gather_image_embeds(
+    embeds: torch.Tensor, is_embed: Optional[torch.Tensor]
+) -> Optional[torch.Tensor]:
+    if is_embed is None:
+        return embeds
+    sel = embeds[is_embed]
+    return sel if sel.numel() else None
+
+
+@dataclass
+class ImagePositions:
+    offset: int
+    length: int
+    id: int
+    num_placeholder_tokens: int
+    is_embed: Optional[torch.Tensor] = None
+
+
 class VlmCausalLMBatch(FlashCausalLMBatch):
+    image_inputs: Optional[List[List[Dict[str, torch.Tensor]]]]
+    image_positions: Optional[List[List[ImagePositions]]]
+    encoder_cache: Optional[List[Dict[int, torch.Tensor]]]
     pixel_values: Optional[List[torch.Tensor]]
     pixel_attention_mask: Optional[List[torch.Tensor]]
     image_sizes: Optional[List[Tuple[int, int]]]
     image_grid_thw: Optional[torch.Tensor]
+    cache_entries_to_free: List[Tuple[int, int]]
+    has_image_inputs: bool = False
+    inputs_embeds: Optional[torch.Tensor] = None

     @classmethod
     @tracer.start_as_current_span("concatenate")
     def concatenate(cls, batches):
         batch = super(VlmCausalLMBatch, cls).concatenate(batches)
+
+        batch.image_inputs = []
+        batch.image_positions = []
+        batch.encoder_cache = []
+        for b in batches:
+            if b.image_inputs is not None:
+                batch.image_inputs.extend(b.image_inputs)
+            else:
+                batch.image_inputs.append(None)
+            if b.image_positions is not None:
+                batch.image_positions.extend(b.image_positions)
+            else:
+                batch.image_positions.append(None)
+            if b.encoder_cache is not None:
+                batch.encoder_cache.extend(b.encoder_cache)
+            else:
+                batch.encoder_cache.append(None)
+
         batch.pixel_values = None
         batch.pixel_attention_mask = None
         batch.image_sizes = None
         batch.image_grid_thw = None
+        batch.inputs_embeds = None
+
+        # To be filled in prepare_for_prefill
+        batch.has_image_inputs = False
+        batch.cache_entries_to_free = []
+
         return batch

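Note: `scatter_image_embeds` re-inflates a packed block of image embeddings to the full length of that image's token span (non-embedding slots become NaN placeholders), and `gather_image_embeds` packs them back. Together they let a cached image span be sliced by token position during chunked prefill. A tiny round trip, independent of the server code:

import torch

hidden = 4
# 5-token image span, of which positions 1..3 are real placeholder tokens.
is_embed = torch.tensor([False, True, True, True, False])
packed = torch.randn(3, hidden)                      # one embedding per placeholder token

full = packed.new_full((is_embed.shape[0], hidden), float("nan"))
full[is_embed] = packed                              # scatter: span-aligned, NaN elsewhere

window = full[1:4]                                   # slice by token position (one prefill chunk)
repacked = window[is_embed[1:4]]                     # gather: drop non-embedding slots
assert torch.equal(repacked, packed)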
@tracer.start_as_current_span("filter")
|
@tracer.start_as_current_span("filter")
|
||||||
def filter(self, request_ids: List[int]):
|
def filter(self, request_ids: List[int]):
|
||||||
|
if len(request_ids) == 0:
|
||||||
|
raise ValueError("Batch must have at least one request")
|
||||||
|
|
||||||
|
image_inputs = []
|
||||||
|
image_positions = []
|
||||||
|
encoder_cache = []
|
||||||
|
|
||||||
|
for request_id in request_ids:
|
||||||
|
idx = self.requests_idx_mapping[request_id]
|
||||||
|
image_inputs.append(self.image_inputs[idx])
|
||||||
|
image_positions.append(self.image_positions[idx])
|
||||||
|
encoder_cache.append(self.encoder_cache[idx])
|
||||||
|
|
||||||
batch = super().filter(request_ids)
|
batch = super().filter(request_ids)
|
||||||
batch.pixel_values = None
|
batch.pixel_values = None
|
||||||
batch.pixel_attention_mask = None
|
batch.pixel_attention_mask = None
|
||||||
batch.image_sizes = None
|
batch.image_sizes = None
|
||||||
batch.image_grid_thw = None
|
batch.image_grid_thw = None
|
||||||
|
batch.inputs_embeds = None
|
||||||
|
|
||||||
|
batch.image_inputs = image_inputs
|
||||||
|
batch.image_positions = image_positions
|
||||||
|
batch.encoder_cache = encoder_cache
|
||||||
|
|
||||||
|
# To be filled in prepare_for_prefill
|
||||||
|
batch.has_image_inputs = False
|
||||||
|
batch.cache_entries_to_free = []
|
||||||
return batch
|
return batch
|
||||||
|
|
||||||
     @classmethod
     def batch_tokenized_inputs(
         cls, requests: Iterable[generate_pb2.Request], tokenizer, processor, config
     ):
-        # Process images first. We need all of them so that the processor
-        # can make the image splits the same size. And we need the final
-        # sizes to insert correct number of image tokens.
-        images = []
+        kwargs = {}
+        if (
+            hasattr(processor, "image_processor_class")
+            and processor.image_processor_class == "Idefics3ImageProcessor"
+        ):
+            kwargs["return_row_col_info"] = True
+
+        max_length = 0
+        vocab = tokenizer.get_vocab()
+
+        if not hasattr(config, "image_token_index"):
+            config.image_token_index = config.image_token_id
+
+        batch_tokenized_inputs: List[List[int]] = []
+        batch_image_inputs: List[Optional[List[dict]]] = []
+        batch_image_positions: List[Optional[List[ImagePositions]]] = []
+
         for r in requests:
+            text_parts = []
+            image_inputs = []
+            image_texts = []
+
+            image_id = 0
+
             for chunk in r.input_chunks.chunks:
                 chunk_type = chunk.WhichOneof("chunk")
                 if chunk_type == "text":
-                    pass
+                    text = preprocess_text(config, chunk.text)
+                    text_parts.append(text)
                 elif chunk_type == "image":
-                    image = Image.open(BytesIO(chunk.image.data))
-                    # qwen2_vl expects images to be greater than 20 pixels, this is for warmup since the
-                    # default warmup image is 20x20
-                    if config.model_type in {"qwen2_vl", "qwen2_5_vl"}:
-                        if image.width <= 20:
-                            w = image.width * 2
-                            h = image.height * 2
-                            image = image.resize((w, h))
-
-                    if config.model_type == "llava_next":
-                        images.append(image)
-                    elif config.model_type == "gemma3":
-                        images.append(image)
-                    elif config.model_type == "llama4":
-                        images.append(image)
-                    else:
-                        images.append([image])
+                    img = Image.open(BytesIO(chunk.image.data))
+                    img = preprocess_image(config, img)
+
+                    image_input = processor.image_processor(
+                        [img], return_tensors="pt", **kwargs
+                    )
+                    image_inputs.append(image_input)
+
+                    img_text, img_start_token_str = image_text_replacement(
+                        processor, image_input, config
+                    )
+                    text_parts.append(img_text)
+
+                    image_texts.append([image_id, img_start_token_str, img_text])
+                    image_id += 1
                 else:
                     raise RuntimeError(f"Invalid chunk type {chunk_type}")

-        if images:
-            kwargs = {}
-            if (
-                hasattr(processor, "image_processor_class")
-                and processor.image_processor_class == "Idefics3ImageProcessor"
-            ):
-                kwargs["return_row_col_info"] = True
-
-            image_inputs = processor.image_processor(
-                images, return_tensors="pt", **kwargs
-            )
-        else:
-            image_inputs = None
-
-        batch_tokenized_inputs = []
-        max_length = 0
-        image_id = 0
-        for r in requests:
-            full_text = ""
-            for chunk in r.input_chunks.chunks:
-                chunk_type = chunk.WhichOneof("chunk")
-                if chunk_type == "text":
-                    full_text += chunk.text
-                elif chunk_type == "image":
-                    full_text += image_text_replacement(
-                        processor, image_inputs, config, image_id
-                    )
-                    image_id += 1
-            # from pdb import set_trace; set_trace()
-            full_text = image_text_replacement_fixup(config, full_text)
+            full_text = image_text_replacement_fixup(config, "".join(text_parts))
+
             input_ids = tokenizer(
                 full_text,
                 truncation=True,
                 max_length=r.truncate,
-                add_special_tokens=r.add_special_tokens,
+                add_special_tokens=(
+                    r.add_special_tokens if config.model_type != "paligemma" else False
+                ),
             )["input_ids"]
             max_length = max(max_length, len(input_ids))
-            batch_tokenized_inputs.append(input_ids)

-        return batch_tokenized_inputs, image_inputs
+            if len(image_inputs) > 0:
+                img_start_token = vocab[image_texts[0][1]]
+                image_positions = cls.get_image_positions(
+                    input_ids, image_texts, img_start_token, config, tokenizer
+                )
+            else:
+                image_inputs = None
+                image_positions = None
+
+            batch_tokenized_inputs.append(input_ids)
+            batch_image_inputs.append(image_inputs)
+            batch_image_positions.append(image_positions)
+
+        return batch_tokenized_inputs, batch_image_inputs, batch_image_positions
+
+    @classmethod
+    def get_image_positions(
+        cls,
+        input_ids: List[int],
+        image_texts: List[Tuple[int, str, str]],
+        img_start_token: int,
+        config,
+        tokenizer: PreTrainedTokenizerBase,
+    ) -> List[ImagePositions]:
+        image_positions = []
+        num_images = len(image_texts)
+
+        input_ids_t = torch.as_tensor(input_ids)
+        img_start_token_pos = torch.where(input_ids_t.eq(img_start_token))[0]
+        num_tokens = input_ids_t.numel()
+
+        last_pos = 0
+        for i in range(num_images):
+            image_id, img_start_token_str, img_text = image_texts[i]
+            img_text = image_text_replacement_fixup(config, img_text)
+
+            if config.model_type == "gemma3":
+                img_text = img_text.replace("\n\n", "")
+
+            tokens = tokenizer(img_text, add_special_tokens=False, return_tensors="pt")[
+                "input_ids"
+            ][0]
+            length = tokens.numel()
+
+            assert (
+                length <= num_tokens
+            ), f"{length} > {num_tokens} Image is truncated, try increasing --max-batch-prefill-tokens"
+
+            pos = torch.searchsorted(img_start_token_pos, last_pos, right=False)
+            index = img_start_token_pos[pos]
+            assert torch.equal(
+                input_ids_t[index : index + length], tokens
+            ), "Image tokens not found in input_ids"
+
+            is_embed = tokens == config.image_token_index
+            num_placeholder_tokens = int(is_embed.sum())
+            if num_placeholder_tokens == length:
+                is_embed = None
+
+            pos = ImagePositions(
+                offset=index,
+                length=length,
+                id=image_id,
+                num_placeholder_tokens=num_placeholder_tokens,
+                is_embed=is_embed,
+            )
+
+            image_positions.append(pos)
+            last_pos = index + length
+
+            if (
+                config.model_type == "idefics2"
+                and i + 1 != num_images
+                and input_ids[last_pos] == config.image_token_index
+            ):
+                fake_token = last_pos - 1
+                fake_token_index = torch.searchsorted(
+                    img_start_token_pos, fake_token, right=False
+                )
+                img_start_token_pos[fake_token_index] = last_pos
+                image_texts[i + 1][2] = image_texts[i + 1][2][
+                    len(img_start_token_str) :
+                ]
+
+        return image_positions
+
     @classmethod
     def from_pb_processor(
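Note: `get_image_positions` re-tokenizes each image's replacement string, then finds where that span starts in the full prompt by taking the next occurrence of the image-start token at or after the previous image's end. A compact illustration of that lookup with made-up token ids:

import torch

START = 7                                   # hypothetical image-start token id
input_ids = torch.tensor([1, 7, 9, 9, 2, 7, 9, 9, 3])
start_positions = torch.where(input_ids.eq(START))[0]   # tensor([1, 5])

last_pos = 0
spans = []
for _ in range(2):                          # two images in this toy prompt
    # first start position that is >= last_pos
    idx = start_positions[torch.searchsorted(start_positions, last_pos, right=False)]
    length = 3                              # length of the re-tokenized image text, fixed here
    spans.append((int(idx), length))
    last_pos = int(idx) + length

assert spans == [(1, 3), (5, 3)]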
@@ -354,33 +534,164 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
         dtype: torch.dtype,
         device: torch.device,
     ) -> "VlmCausalLMBatch":
-        batch_tokenized_inputs, image_inputs = cls.batch_tokenized_inputs(
-            pb.requests, tokenizer, processor, config
+        batch_tokenized_inputs, image_inputs, image_positions = (
+            cls.batch_tokenized_inputs(pb.requests, tokenizer, processor, config)
         )
         batch = cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device)
-        if image_inputs is not None:
-            batch.pixel_values = image_inputs["pixel_values"].to(device=device)
-            if "pixel_attention_mask" in image_inputs:
-                batch.pixel_attention_mask = image_inputs["pixel_attention_mask"].to(
-                    device=device
-                )
-            else:
-                batch.pixel_attention_mask = None
-            if "image_sizes" in image_inputs:
-                batch.image_sizes = image_inputs["image_sizes"].to(device=device)
-            else:
-                batch.image_sizes = None
-            if "image_grid_thw" in image_inputs:
-                batch.image_grid_thw = image_inputs["image_grid_thw"].to(device=device)
-            else:
-                batch.image_grid_thw = None
-        else:
+        batch.image_inputs = image_inputs
+        batch.image_positions = image_positions
+        batch.encoder_cache = [{} for _ in range(len(pb.requests))]
+        if len(image_inputs):
             batch.pixel_values = None
             batch.pixel_attention_mask = None
             batch.image_sizes = None
             batch.image_grid_thw = None
         return batch
+
+    def prepare_for_prefill(self):
+        super().prepare_for_prefill()
+
+        self.has_image_inputs = False
+        self.cache_entries_to_free = []
+
+        self.pixel_values = []
+
+        assert (
+            len(self.cache_lengths)
+            == len(self.input_lengths)
+            == len(self.prefilling_mask)
+        ), "Mismatch in lengths of cache_lengths, input_lengths, and prefilling_mask"
+
+        for i, (
+            cache_length,
+            input_length,
+            request_prefilling,
+        ) in enumerate(
+            zip(
+                self.cache_lengths,
+                self.input_lengths,
+                self.prefilling_mask,
+            )
+        ):
+            if not request_prefilling or self.image_positions[i] is None:
+                continue
+
+            for image_position in self.image_positions[i]:
+                if image_position is None:
+                    continue
+                start_pos = image_position.offset
+                length = image_position.length
+
+                if start_pos >= cache_length + input_length:
+                    # No encoder input required at this step
+                    break
+                if start_pos + length <= cache_length:
+                    # The encode input is already processed
+                    continue
+
+                self.has_image_inputs = True
+
+                if image_position.id not in self.encoder_cache[i]:
+                    image_inputs = self.image_inputs[i][image_position.id]
+                    self.pixel_values.append((i, image_position.id, image_inputs))
+
+                    # Remove the image from the image_inputs
+                    self.image_inputs[i][image_position.id] = None
+
+        if not self.has_image_inputs:
+            self.pixel_values = None
+            self.pixel_attention_mask = None
+            self.image_sizes = None
+            self.image_grid_thw = None
+        else:
+            image_grid_thw_list = [
+                x[2]["image_grid_thw"]
+                for x in self.pixel_values
+                if "image_grid_thw" in x[2]
+            ]
+            if image_grid_thw_list:
+                self.image_grid_thw = torch.cat(image_grid_thw_list, dim=0).to(
+                    self.input_ids.device
+                )
+            else:
+                self.image_grid_thw = None
+
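Note: the loop above decides, per request and per image, whether the current prefill chunk needs any encoder work. Images entirely to the right of the chunk are skipped with break, images whose span was fully covered by earlier chunks are skipped with continue, and everything else is queued for encoding. The same interval arithmetic, reduced to plain numbers (values are arbitrary):

def image_needed(start_pos, length, cache_length, input_length):
    # tokens [cache_length, cache_length + input_length) are prefilled in this chunk
    if start_pos >= cache_length + input_length:
        return "later"        # image starts after this chunk
    if start_pos + length <= cache_length:
        return "done"         # image was fully covered by earlier chunks
    return "encode"           # image overlaps the current chunk

# An image spanning tokens [100, 356) of the prompt, prefilled in 200-token chunks:
assert image_needed(100, 256, 0, 200) == "encode"    # first chunk covers tokens [0, 200)
assert image_needed(100, 256, 200, 200) == "encode"  # second chunk still overlaps the span
assert image_needed(100, 256, 400, 200) == "done"    # the span ended at token 356 < 400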
+    def update_encoder_cache(self, encoder_outputs, request_id, img_pos):
+        self.encoder_cache[request_id][img_pos.id] = scatter_image_embeds(
+            encoder_outputs, img_pos.is_embed
+        )
+
+    def gather_vision_embeds(self):
+        device = self.input_ids.device
+        chunks = []
+        for (
+            i,
+            cache_length,
+            input_length,
+            request_prefilling,
+        ) in zip(
+            range(len(self.requests)),
+            self.cache_lengths,
+            self.input_lengths,
+            self.prefilling_mask,
+        ):
+            if not request_prefilling or self.image_positions[i] is None:
+                continue
+
+            for image_position in self.image_positions[i]:
+                if image_position is None:
+                    continue
+                start_pos = image_position.offset
+                length = image_position.length
+
+                if start_pos >= cache_length + input_length:
+                    # No encoder input required at this step
+                    break
+                if start_pos + length <= cache_length:
+                    # The encode input is already processed
+                    continue
+
+                start_idx = max(cache_length - start_pos, 0)
+                end_idx = min(cache_length - start_pos + input_length, length)
+
+                assert (
+                    image_position.id in self.encoder_cache[i]
+                ), f"image_id {image_position.id} not in encoder_cache {self.encoder_cache[i]}"
+                encoder_output = self.encoder_cache[i][image_position.id]
+
+                is_embed = image_position.is_embed
+                if is_embed is not None:
+                    is_embed = is_embed[start_idx:end_idx]
+
+                from loguru import logger
+
+                logger.info(
+                    f"image_id {image_position.id} start_idx {start_idx} end_idx {end_idx}, length {length}"
+                )
+
+                embeds = gather_image_embeds(
+                    encoder_output[start_idx:end_idx],
+                    is_embed=is_embed,
+                )
+                if embeds is not None:
+                    chunks.append(embeds)
+
+                if end_idx == length:
+                    self.cache_entries_to_free.append((i, image_position.id))
+                    self.image_positions[i][image_position.id] = None
+
+        if len(chunks) == 0:
+            return None
+        return torch.cat(chunks, dim=0).to(device)
+
+    def free_encoder_cache(self):
+        for i, image_id in self.cache_entries_to_free:
+            self.encoder_cache[i].pop(image_id, None)
+
+        self.cache_entries_to_free = []
+
+        # release any freed GPU memory immediately?
+
+
 class VlmCausalLM(FlashCausalLM):
     def __init__(
@@ -392,6 +703,7 @@ class VlmCausalLM(FlashCausalLM):
         batch_class=VlmCausalLMBatch,
         revision,
         trust_remote_code: bool,
+        support_chunking: bool = True,
         **kwargs,
     ):
         if PREFIX_CACHING:
@@ -409,8 +721,7 @@ class VlmCausalLM(FlashCausalLM):
             model_id=model_id,
             revision=revision,
             trust_remote_code=trust_remote_code,
-            # FIXME: VLM do not work with context chunking yet
-            support_chunking=False,
+            support_chunking=support_chunking,
             **kwargs,
         )

@@ -418,6 +729,227 @@ class VlmCausalLM(FlashCausalLM):
     def batch_type(self) -> Type[VlmCausalLMBatch]:
         return self.batch_class

+    def cuda_graph_warmup(self, bs: int, max_s: int, max_bt: int):
+        max_bs = max(self.cuda_graphs.keys()) if self.cuda_graphs else None
+        input_lengths = [max_s] * bs
+        cache_lengths = [0] * bs
+        config = getattr(self.model.config, "text_config", self.model.config)
+        if max_bs is None:
+            inputs_embeds = torch.zeros(
+                (bs, config.hidden_size),
+                device=self.device,
+                dtype=self.dtype,
+            )
+            position_ids = torch.zeros(bs, dtype=torch.int32, device=self.device)
+            config = getattr(self.model, "config", None)
+            rope_scaling = getattr(config, "rope_scaling", None) if config else None
+            if (  # mrope have position_ids per section, if so repeat n times
+                isinstance(rope_scaling, dict) and rope_scaling["rope_type"] == "mrope"
+            ):
+                n_sections = len(self.model.config.rope_scaling["mrope_section"])
+                position_ids = position_ids.unsqueeze(1).repeat(1, n_sections)
+            slots = torch.arange(bs, dtype=torch.int64, device=self.device)
+            input_lengths_tensor = (
+                torch.ones(bs, dtype=torch.int32, device=self.device) * max_s
+            )
+            cache_lengths_tensor = torch.zeros(
+                bs, dtype=torch.int32, device=self.device
+            )
+            block_tables = torch.arange(
+                max_bt, dtype=torch.int32, device=self.device
+            ).repeat(bs)
+            block_tables = block_tables.reshape((bs, max_bt))
+            if ATTENTION == "flashinfer":
+                block_tables = block_tables_to_ragged(
+                    block_tables=block_tables,
+                    input_lengths=input_lengths,
+                    cache_lengths=cache_lengths,
+                    input_lengths_tensor=input_lengths_tensor,
+                    cache_lengths_tensor=cache_lengths_tensor,
+                    max_current_length=max_s,
+                )
+        else:
+            if bs > max_bs:
+                raise RuntimeError(
+                    "Cuda graphs should be generated in decreasing order size to reduce VRAM usage"
+                )
+            inputs_embeds = self.cuda_graphs[max_bs]["inputs_embeds"][:bs]
+            position_ids = self.cuda_graphs[max_bs]["position_ids"][:bs]
+            if ATTENTION == "flashinfer":
+                block_tables = self.cuda_graphs[max_bs]["block_tables"][: bs * max_bt]
+            else:
+                block_tables = self.cuda_graphs[max_bs]["block_tables"][:bs]
+            slots = self.cuda_graphs[max_bs]["slots"][:bs]
+            input_lengths_tensor = self.cuda_graphs[max_bs]["input_lengths"][:bs]
+            cache_lengths_tensor = self.cuda_graphs[max_bs]["cache_lengths"][:bs]
+
+        if ATTENTION == "flashinfer":
+            from text_generation_server.layers.attention.flashinfer import (
+                create_decode_state_cuda_graphs,
+            )
+
+            block_tables_ptr = torch.zeros(
+                bs + 1, dtype=torch.int32, device=self.device
+            )
+            last_page_len = torch.ones(bs, dtype=torch.int32, device=self.device)
+            state = create_decode_state_cuda_graphs(
+                device=inputs_embeds.device,
+                block_tables=block_tables,
+                block_tables_ptr=block_tables_ptr,
+                last_page_len=last_page_len,
+                num_heads=self.num_heads,
+                num_kv_heads=self.num_kv_heads,
+            )
+        else:
+            state = None
+
+        graph = torch.cuda.CUDAGraph()
+        self.cuda_graphs[bs] = {
+            "inputs_embeds": inputs_embeds,
+            "position_ids": position_ids,
+            "kv_cache": self.kv_cache,
+            "block_tables": block_tables,
+            "slots": slots,
+            "input_lengths": input_lengths_tensor,
+            "cache_lengths": cache_lengths_tensor,
+            "state": state,
+            "graph": graph,
+        }
+
+        torch.cuda.synchronize()
+        # Run once outside to warmup
+        with self._forward_context(
+            block_tables=block_tables,
+            cu_seqlen_prefill=None,
+            input_lengths_tensor=input_lengths_tensor,
+            state=state,
+            cache_lengths_tensor=cache_lengths_tensor,
+        ):
+            seqlen = Seqlen(
+                input_lengths=input_lengths_tensor,
+                cache_lengths=cache_lengths_tensor,
+                cu_seqlen_q=None,
+                max_q=1,
+                max_k=max_s,
+            )
+            self.model.forward(
+                inputs_embeds=inputs_embeds,
+                position_ids=position_ids,
+                cu_seqlen_prefill=None,
+                kv_cache=self.kv_cache,
+                block_tables=block_tables,
+                slots=slots,
+                seqlen=seqlen,
+                max_s=max_s,
+                prefill_cache_indices=None,
+                lm_head_indices=None,
+            )
+            del seqlen
+
+            torch.cuda.synchronize()
+
+            with torch.cuda.graph(graph, pool=MEM_POOL):
+                seqlen = Seqlen(
+                    input_lengths=input_lengths_tensor,
+                    cache_lengths=cache_lengths_tensor,
+                    cu_seqlen_q=None,
+                    max_q=1,
+                    max_k=max_s,
+                )
+                logits, speculative_logits = self.model.forward(
+                    inputs_embeds=inputs_embeds,
+                    position_ids=position_ids,
+                    cu_seqlen_prefill=None,
+                    kv_cache=self.kv_cache,
+                    block_tables=block_tables,
+                    slots=slots,
+                    seqlen=seqlen,
+                    max_s=max_s,
+                    prefill_cache_indices=None,
+                    lm_head_indices=None,
+                )
+                self.cuda_graphs[bs]["logits"] = logits
+                self.cuda_graphs[bs]["speculative_logits"] = speculative_logits
+        torch.cuda.synchronize()
+
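Note: the warmup above captures decode with `inputs_embeds` as the static graph input rather than `input_ids`, since for VLMs the merged text and vision embeddings are what actually feed the transformer. The capture/replay pattern itself, stripped of the server specifics (shapes and the toy model are invented; a CUDA device is assumed):

import torch

model = torch.nn.Linear(16, 16).cuda()
static_in = torch.zeros(4, 16, device="cuda")       # static buffer reused at replay time

# warmup run on a side stream before capture
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    model(static_in)
torch.cuda.current_stream().wait_stream(s)

graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    static_out = model(static_in)

# replay: copy fresh data into the captured buffer, then relaunch the recorded kernels
static_in.copy_(torch.randn(4, 16, device="cuda"))
graph.replay()
print(static_out[0, :4])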
+    def get_vision_embeds(
+        self,
+        pixel_values: torch.Tensor,
+        pixel_attention_mask: torch.Tensor,
+        image_sizes: torch.Tensor,
+        image_grid_thw: torch.Tensor,
+    ):
+        embeds = self.model.get_vision_embeds(
+            pixel_values=pixel_values,
+            pixel_attention_mask=pixel_attention_mask,
+            image_sizes=image_sizes,
+            image_grid_thw=image_grid_thw,
+        )
+        return embeds
+
+    def get_inputs_embeds(
+        self,
+        input_ids: torch.Tensor,
+        vision_embeds: Optional[torch.Tensor] = None,
+    ):
+        return self.model.get_inputs_embeds(
+            input_ids=input_ids,
+            vision_embeds=vision_embeds,
+        )
+
+    def encode_images(self, batch):
+        if batch.pixel_values is not None:
+            device = batch.input_ids.device
+            for request_id, image_id, image_input in batch.pixel_values:
+                pixel_values = image_input["pixel_values"].to(device)
+
+                if "pixel_attention_mask" in image_input:
+                    pixel_attention_mask = image_input["pixel_attention_mask"].to(
+                        device
+                    )
+                else:
+                    pixel_attention_mask = None
+
+                if "image_sizes" in image_input:
+                    image_sizes = image_input["image_sizes"].to(device)
+                else:
+                    image_sizes = None
+
+                if "image_grid_thw" in image_input:
+                    image_grid_thw = image_input["image_grid_thw"].to(device)
+                else:
+                    image_grid_thw = None
+
+                encoder_outputs = self.get_vision_embeds(
+                    pixel_values=pixel_values,
+                    pixel_attention_mask=pixel_attention_mask,
+                    image_sizes=image_sizes,
+                    image_grid_thw=image_grid_thw,
+                )
+                batch.update_encoder_cache(
+                    encoder_outputs,
+                    request_id,
+                    batch.image_positions[request_id][image_id],
+                )
+
+        batch.pixel_values = None
+        batch.pixel_attention_mask = None
+        batch.image_sizes = None
+
+    def set_inputs_embeds(self, batch):
+        if batch.has_image_inputs:
+            self.encode_images(batch)
+            vision_embeds = batch.gather_vision_embeds()
+            batch.has_image_inputs = False
+        else:
+            vision_embeds = None
+
+        inputs_embeds = self.get_inputs_embeds(
+            batch.input_ids, vision_embeds=vision_embeds
+        )
+
+        batch.inputs_embeds = inputs_embeds
+
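Note: before each model call, `set_inputs_embeds` encodes whatever images overlap the current chunk, pulls the matching slice of cached vision embeddings, and merges it with the token embeddings; encoder outputs are cached per request and per image id so an image spanning several chunks is only encoded once. A toy version of that cache bookkeeping (ids and tensor shapes are made up):

import torch

encoder_cache = [{} for _ in range(2)]          # one dict per request, keyed by image id
cache_entries_to_free = []

# after encoding image 0 of request 1:
encoder_cache[1][0] = torch.randn(256, 64)

# once a later chunk consumes the tail of that image's span:
cache_entries_to_free.append((1, 0))

# free_encoder_cache equivalent:
for i, image_id in cache_entries_to_free:
    encoder_cache[i].pop(image_id, None)
cache_entries_to_free = []
assert encoder_cache == [{}, {}]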
     def forward(
         self,
         batch: VlmCausalLMBatch,
@@ -468,6 +1000,7 @@ class VlmCausalLM(FlashCausalLM):
             position_ids = new_position_ids
         else:
             input_ids = batch.input_ids
+            inputs_embeds = batch.inputs_embeds
             position_ids = batch.position_ids
             cu_seqlen_prefill = batch.cu_seqlen_prefill
             kv_cache = self.kv_cache
@@ -485,13 +1018,17 @@ class VlmCausalLM(FlashCausalLM):
             )
             batch.position_ids = position_ids

+        attention_mask = None
+        attention_mask_forward = None
         if self.model.config.model_type == "gemma3" and cu_seqlen_prefill is not None:
-            # Get the mask, needed for flashinfer.
             attention_mask = self.model.get_attention_mask(
                 input_ids, cu_seqlen_prefill, self.dtype, bool_mask=True
-            ).reshape(-1)
-        else:
-            attention_mask = None
+            )
+            min_dtype = torch.finfo(self.dtype).min
+            attention_mask_forward = torch.where(attention_mask, 0, min_dtype).to(
+                input_ids.device
+            )
+            attention_mask = attention_mask.reshape(-1)

         # Try to find an associated cuda graph
         bs = input_ids.shape[0]
@@ -526,7 +1063,7 @@ class VlmCausalLM(FlashCausalLM):
                 max_k=batch.max_current_length,
             )
             logits, speculative_logits = self.model.forward(
-                input_ids=input_ids,
+                inputs_embeds=inputs_embeds,
                 position_ids=position_ids,
                 cu_seqlen_prefill=cu_seqlen_prefill,
                 kv_cache=kv_cache,
@@ -536,26 +1073,17 @@ class VlmCausalLM(FlashCausalLM):
                 max_s=max_s,
                 prefill_cache_indices=batch.prefill_cache_indices,
                 lm_head_indices=lm_head_indices,
-                pixel_values=batch.pixel_values,
-                pixel_attention_mask=batch.pixel_attention_mask,
-                image_sizes=batch.image_sizes,
-                image_grid_thw=batch.image_grid_thw,
+                attention_mask=attention_mask_forward,
             )
             if batch.prefill_cache_indices is not None:
                 batch.prefill_cache_indices = None
-            if batch.pixel_values is not None:
-                batch.pixel_values = None
-            if batch.pixel_attention_mask is not None:
-                batch.pixel_attention_mask = None
-            if batch.image_sizes is not None:
-                batch.image_sizes = None
-            if batch.image_grid_thw is not None:
-                batch.image_grid_thw = None
+            batch.image_grid_thw = None
+            batch.free_encoder_cache()
             return logits, speculative_logits

         # Copy inputs to the static inputs of the cuda graph
         # Static inputs are potentially padded
-        cuda_graph["input_ids"][: input_ids.shape[0]] = input_ids
+        cuda_graph["inputs_embeds"][: inputs_embeds.shape[0]] = inputs_embeds
         cuda_graph["position_ids"][: position_ids.shape[0]] = position_ids
         if ATTENTION == "flashinfer":
             block_tables = block_tables_to_ragged(
@@ -600,4 +1128,6 @@ class VlmCausalLM(FlashCausalLM):
                 else None
             )
         logits = cuda_graph["logits"][:bs]
+
+        batch.free_encoder_cache()
         return logits, speculative_logits
@@ -18,7 +18,6 @@ from text_generation_server.utils.adapter import AdapterInfo
 from text_generation_server.utils.prefill_chunking import set_max_prefill_tokens

 try:
-    from text_generation_server.models.pali_gemma import PaliGemmaBatch
     from text_generation_server.models.vlm_causal_lm import (
         VlmCausalLMBatch,
     )
@@ -26,7 +25,6 @@ try:
     from text_generation_server.models.mllama_causal_lm import MllamaCausalLMBatch

     VLM_BATCH_TYPES = {
-        PaliGemmaBatch,
         VlmCausalLMBatch,
         IdeficsCausalLMBatch,
         MllamaCausalLMBatch,