Mirror of https://github.com/huggingface/text-generation-inference.git
Synced 2025-06-19 07:42:06 +00:00
llama4 and some issue fixes
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
This commit is contained in:
parent b09d4cc142
commit 93e5e35f9d
@@ -1356,55 +1356,36 @@ class Llama4ForConditionalGeneration(nn.Module):
         hidden_state = self.vision_model(pixel_values)
         return hidden_state
 
-    def forward(
+    def get_vision_embeds(
         self,
-        input_ids: torch.LongTensor = None,
-        pixel_values: torch.FloatTensor = None,
-        pixel_attention_mask=None,
-        position_ids: Optional[torch.LongTensor] = None,
-        cu_seqlen_prefill: Optional[torch.Tensor] = None,
-        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]] = None,
-        slots: torch.Tensor = None,
-        seqlen: Seqlen = None,
-        hpu_attention_meta: Optional[HPUPagedAttentionMetadata] = None,
-        inputs_embeds: Optional[torch.FloatTensor] = None,
-        vision_feature_layer: Optional[Union[int, List[int]]] = None,
-        vision_feature_select_strategy: Optional[str] = None,
-        image_sizes: torch.Tensor = None,
-        lm_head_indices: Optional[torch.Tensor] = None,
-        adapter_data: Optional[torch.Tensor] = None,
-        **lm_kwargs,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-
-        def _get_padding_mask(input_ids, pad_token_id=0):
-            return (input_ids != pad_token_id).long()
-
-        attention_mask = _get_padding_mask(input_ids)
-        attention_mask = attention_mask.view(seqlen.input_lengths.shape[0], -1)
-        inputs_embeds = self.text_model.model.embed_tokens(input_ids)
-        vision_feature_layer = (
-            vision_feature_layer
-            if vision_feature_layer is not None
-            else self.config.vision_config.vision_feature_layer
-        )
-        vision_feature_select_strategy = (
-            vision_feature_select_strategy
-            if vision_feature_select_strategy is not None
-            else self.config.vision_config.vision_feature_select_strategy
-        )
-
-        if pixel_values is not None:
-            image_features = self.get_image_features(
-                pixel_values=pixel_values,
-                vision_feature_layer=vision_feature_layer,
-                vision_feature_select_strategy=vision_feature_select_strategy,
-                image_sizes=image_sizes,
-            )
-            original_inputs_embeds_shape = inputs_embeds.shape
-
-            vision_flat = image_features.view(-1, image_features.size(-1))
-            projected_vision_flat = self.multi_modal_projector(vision_flat)
-
-            special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(
-                -1
-            )
+        pixel_values: torch.FloatTensor,
+        pixel_attention_mask: Optional[torch.FloatTensor] = None,
+        image_sizes: Optional[torch.Tensor] = None,
+        image_grid_thw: Optional[torch.LongTensor] = None,
+    ):
+        image_features = self.get_image_features(
+            pixel_values=pixel_values,
+            vision_feature_layer=self.config.vision_config.vision_feature_layer,
+            vision_feature_select_strategy=self.config.vision_config.vision_feature_select_strategy,
+            image_sizes=image_sizes,
+        )
+        vision_flat = image_features.view(-1, image_features.size(-1))
+        image_features = self.multi_modal_projector(vision_flat)
+        return image_features
+
+    def get_inputs_embeds(
+        self,
+        input_ids: torch.Tensor,
+        vision_embeds: torch.Tensor = None,
+        pixel_values: torch.FloatTensor = None,
+        image_sizes: Optional[torch.LongTensor] = None,
+    ):
+        inputs_embeds = self.text_model.model.embed_tokens(input_ids)
+
+        if vision_embeds is not None:
+            # When we generate, we don't want to replace the potential image_token_id that we generated by images
+            # that simply don't exist
+            original_inputs_embeds_shape = inputs_embeds.shape
+            special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(
+                -1
+            )
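
The hunk above splits the old monolithic forward into separate stages: get_vision_embeds (vision tower plus multi_modal_projector), get_inputs_embeds (text embedding plus splicing of vision features at the image-token positions), and, in the next hunk, a plain forward over the merged embeddings. A self-contained toy sketch of that call pattern follows; the class, shapes, and dummy tensors here are illustrative assumptions, not part of this diff.

import torch
import torch.nn as nn

# Toy stand-in for the refactored pattern: encode images once, merge them into
# the token embeddings, then run the language model on inputs_embeds only.
class ToyMultimodalModel(nn.Module):
    def __init__(self, vocab_size=16, hidden=8, image_token_id=3):
        super().__init__()
        self.image_token_id = image_token_id
        self.embed_tokens = nn.Embedding(vocab_size, hidden)
        self.projector = nn.Linear(4, hidden)  # stand-in for multi_modal_projector

    def get_vision_embeds(self, pixel_values):
        # stand-in for vision_model(...) followed by the projector
        return self.projector(pixel_values.view(-1, 4))

    def get_inputs_embeds(self, input_ids, vision_embeds=None):
        inputs_embeds = self.embed_tokens(input_ids).clone()
        if vision_embeds is not None:
            # overwrite the rows that hold the image placeholder token
            inputs_embeds[input_ids == self.image_token_id] = vision_embeds
        return inputs_embeds

    def forward(self, inputs_embeds):
        return inputs_embeds.sum(dim=-1)  # stand-in for the text model

model = ToyMultimodalModel()
input_ids = torch.tensor([[1, 3, 3, 2]])  # two image-token slots
pixel_values = torch.randn(2, 4)          # one feature row per slot
vision_embeds = model.get_vision_embeds(pixel_values)
inputs_embeds = model.get_inputs_embeds(input_ids, vision_embeds)
print(model(inputs_embeds).shape)         # torch.Size([1, 4])
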
@@ -1414,19 +1395,33 @@ class Llama4ForConditionalGeneration(nn.Module):
             final_mask_1d = final_mask[..., 0].reshape(-1)
             num_tokens_to_fill = final_mask_1d.sum()
 
-            if num_tokens_to_fill != projected_vision_flat.size(0):
+            if num_tokens_to_fill != vision_embeds.size(0):
                 raise ValueError(
                     f"Mismatch: final_mask wants {num_tokens_to_fill} embeddings, "
-                    f"but multi_modal_projector returned {projected_vision_flat.size(0)}"
+                    f"but multi_modal_projector returned {vision_embeds.size(0)}"
                 )
 
             expanded_mask = final_mask_1d.unsqueeze(-1).expand(
                 -1, inputs_embeds.size(-1)
             )
-            inputs_embeds = inputs_embeds.masked_scatter(
-                expanded_mask, projected_vision_flat
-            )
+            inputs_embeds = inputs_embeds.masked_scatter(expanded_mask, vision_embeds)
             inputs_embeds = inputs_embeds.view(original_inputs_embeds_shape)
+        return inputs_embeds
+
+    def forward(
+        self,
+        inputs_embeds: torch.Tensor,
+        position_ids: Optional[torch.LongTensor] = None,
+        cu_seqlen_prefill: Optional[torch.Tensor] = None,
+        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]] = None,
+        slots: torch.Tensor = None,
+        seqlen: Seqlen = None,
+        hpu_attention_meta: Optional[HPUPagedAttentionMetadata] = None,
+        lm_head_indices: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        adapter_data: Optional[torch.Tensor] = None,
+        **lm_kwargs,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
 
         logits, speculative_logits = self.text_model(
             inputs_embeds,
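
For clarity, a standalone walk-through of the scatter step shown above on dummy shapes: the image-token mask is flattened, its population is checked against the number of vision rows, expanded to the hidden size, and used by masked_scatter before restoring the original (batch, seq, hidden) shape. Everything outside the dummy tensors mirrors the hunk itself.

import torch

# Dummy-data walk-through of the masked_scatter sequence in get_inputs_embeds.
batch, seq, hidden = 1, 5, 4
inputs_embeds = torch.zeros(batch * seq, hidden)          # flattened text embeddings
input_ids = torch.tensor([[7, 9, 9, 7, 7]])               # 9 = image placeholder token
original_inputs_embeds_shape = (batch, seq, hidden)

special_image_mask = (input_ids == 9).unsqueeze(-1)
final_mask_1d = special_image_mask[..., 0].reshape(-1)    # (batch * seq,)
vision_embeds = torch.ones(2, hidden)                     # one row per image token

num_tokens_to_fill = final_mask_1d.sum()
if num_tokens_to_fill != vision_embeds.size(0):
    raise ValueError("mask / vision embedding count mismatch")

expanded_mask = final_mask_1d.unsqueeze(-1).expand(-1, inputs_embeds.size(-1))
inputs_embeds = inputs_embeds.masked_scatter(expanded_mask, vision_embeds)
inputs_embeds = inputs_embeds.view(original_inputs_embeds_shape)
print(inputs_embeds[0, :, 0])  # tensor([0., 1., 1., 0., 0.])
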
@@ -285,7 +285,7 @@ def scatter_image_embeds(
         (is_embed.shape[0], embeds.shape[-1]),
         fill_value=torch.nan,
     )
-    placeholders[is_embed] = embeds
+    placeholders[is_embed.to(embeds.device)] = embeds
     return placeholders
 
 
@@ -294,7 +294,7 @@ def gather_image_embeds(
 ) -> Optional[torch.Tensor]:
     if is_embed is None:
         return embeds
-    sel = embeds[is_embed]
+    sel = embeds[is_embed.to(embeds.device)]
     return sel if sel.numel() else None
 
 
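
The two one-line changes above move the boolean mask onto the embeddings' device before indexing; indexing an accelerator tensor (e.g. on HPU) with a CPU mask would otherwise typically fail with a device-mismatch error. A minimal sketch of both helpers, assuming the same structure as the visible lines (the None guard in scatter_image_embeds is an assumption):

import torch

def scatter_image_embeds(embeds, is_embed):
    # spread embedding rows back into a full-length buffer, NaN elsewhere
    if is_embed is None:
        return embeds
    placeholders = embeds.new_full(
        (is_embed.shape[0], embeds.shape[-1]),
        fill_value=torch.nan,
    )
    placeholders[is_embed.to(embeds.device)] = embeds
    return placeholders

def gather_image_embeds(embeds, is_embed):
    # pick out only the rows that correspond to real image embeddings
    if is_embed is None:
        return embeds
    sel = embeds[is_embed.to(embeds.device)]
    return sel if sel.numel() else None

embeds = torch.arange(6.0).view(3, 2)                      # three embedding rows
is_embed = torch.tensor([True, False, True, True, False])  # five token positions
scattered = scatter_image_embeds(embeds, is_embed)         # rows 1 and 4 become NaN
print(torch.equal(gather_image_embeds(scattered, is_embed), embeds))  # True
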
@@ -553,8 +553,12 @@ class FlashVlmCausalLMBatch(FlashCausalLMBatch):
             batch.image_grid_thw = None
         return batch
 
-    def prepare_for_prefill(self):
-        super().prepare_for_prefill()
+    def prepare_for_prefill(
+        self, max_padded_input_len, max_padded_bs, max_total_tokens
+    ):
+        super().prepare_for_prefill(
+            max_padded_input_len, max_padded_bs, max_total_tokens
+        )
 
         self.has_image_inputs = False
         self.cache_entries_to_free = []
@@ -1003,6 +1007,9 @@ class FlashVlmCausalLM(FlashCausalLM):
                 input_ids.device
             )
             attention_mask = attention_mask.reshape(-1)
+        if self.model.config.model_type == "llama4":
+            attention_mask = (input_ids != 0).long()
+            attention_mask_forward = attention_mask.view(input_lengths.shape[0], -1)
 
         if cu_seqlen_prefill is None and self.max_past() is not None:
             # In decode, not prefill, we're actually overwriting the KV-cache
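
The llama4 branch added above rebuilds a padding mask directly from the token ids (0 treated as padding) and reshapes it to one row per sequence. On dummy data:

import torch

# Two right-padded sequences flattened into one tensor, padded length 4.
input_ids = torch.tensor([5, 6, 7, 0, 8, 9, 0, 0])
input_lengths = torch.tensor([3, 2])

attention_mask = (input_ids != 0).long()
attention_mask_forward = attention_mask.view(input_lengths.shape[0], -1)
print(attention_mask_forward)
# tensor([[1, 1, 1, 0],
#         [1, 1, 0, 0]])
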
@@ -46,6 +46,13 @@ class FlashMllamaCausalLMBatch(FlashVlmCausalLMBatch):
     aspect_ratio_mask: Optional[torch.Tensor] = None
     cross_attention_states: Optional[torch.Tensor] = None
 
+    def prepare_for_prefill(
+        self, max_padded_input_len, max_padded_bs, max_total_tokens
+    ):
+        super(FlashVlmCausalLMBatch, self).prepare_for_prefill(
+            max_padded_input_len, max_padded_bs, max_total_tokens
+        )
+
     @classmethod
     @tracer.start_as_current_span("concatenate")
     def concatenate(cls, batches, padded_total_bs: int = 0):
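
Note the two-argument super() in the new override: super(FlashVlmCausalLMBatch, self) starts the method lookup after FlashVlmCausalLMBatch in the MRO, so the Mllama batch bypasses the VLM batch's own prepare_for_prefill and reaches FlashCausalLMBatch's implementation directly. A toy illustration of that lookup rule (class names here are stand-ins, not the TGI classes):

class Base:
    def prepare(self):
        return "Base.prepare"

class Vlm(Base):
    def prepare(self):
        return "Vlm.prepare"

class Mllama(Vlm):
    def prepare(self):
        # analogous to super(FlashVlmCausalLMBatch, self).prepare_for_prefill(...)
        return super(Vlm, self).prepare()

print(Mllama().prepare())  # Base.prepare
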