Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-06-19 07:42:06 +00:00)
llama4 support and some issue fixes
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
This commit is contained in:
parent b09d4cc142
commit 93e5e35f9d
@@ -1356,55 +1356,36 @@ class Llama4ForConditionalGeneration(nn.Module):
         hidden_state = self.vision_model(pixel_values)
         return hidden_state
 
-    def forward(
+    def get_vision_embeds(
         self,
-        input_ids: torch.LongTensor = None,
-        pixel_values: torch.FloatTensor = None,
-        pixel_attention_mask=None,
-        position_ids: Optional[torch.LongTensor] = None,
-        cu_seqlen_prefill: Optional[torch.Tensor] = None,
-        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]] = None,
-        slots: torch.Tensor = None,
-        seqlen: Seqlen = None,
-        hpu_attention_meta: Optional[HPUPagedAttentionMetadata] = None,
-        inputs_embeds: Optional[torch.FloatTensor] = None,
-        vision_feature_layer: Optional[Union[int, List[int]]] = None,
-        vision_feature_select_strategy: Optional[str] = None,
-        image_sizes: torch.Tensor = None,
-        lm_head_indices: Optional[torch.Tensor] = None,
-        adapter_data: Optional[torch.Tensor] = None,
-        **lm_kwargs,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-
-        def _get_padding_mask(input_ids, pad_token_id=0):
-            return (input_ids != pad_token_id).long()
-
-        attention_mask = _get_padding_mask(input_ids)
-        attention_mask = attention_mask.view(seqlen.input_lengths.shape[0], -1)
-        inputs_embeds = self.text_model.model.embed_tokens(input_ids)
-        vision_feature_layer = (
-            vision_feature_layer
-            if vision_feature_layer is not None
-            else self.config.vision_config.vision_feature_layer
-        )
-        vision_feature_select_strategy = (
-            vision_feature_select_strategy
-            if vision_feature_select_strategy is not None
-            else self.config.vision_config.vision_feature_select_strategy
-        )
-
-        if pixel_values is not None:
+        pixel_values: torch.FloatTensor,
+        pixel_attention_mask: Optional[torch.FloatTensor] = None,
+        image_sizes: Optional[torch.Tensor] = None,
+        image_grid_thw: Optional[torch.LongTensor] = None,
+    ):
         image_features = self.get_image_features(
             pixel_values=pixel_values,
-            vision_feature_layer=vision_feature_layer,
-            vision_feature_select_strategy=vision_feature_select_strategy,
+            vision_feature_layer=self.config.vision_config.vision_feature_layer,
+            vision_feature_select_strategy=self.config.vision_config.vision_feature_select_strategy,
             image_sizes=image_sizes,
         )
-        original_inputs_embeds_shape = inputs_embeds.shape
-
         vision_flat = image_features.view(-1, image_features.size(-1))
-        projected_vision_flat = self.multi_modal_projector(vision_flat)
+        image_features = self.multi_modal_projector(vision_flat)
+        return image_features
+
+    def get_inputs_embeds(
+        self,
+        input_ids: torch.Tensor,
+        vision_embeds: torch.Tensor = None,
+        pixel_values: torch.FloatTensor = None,
+        image_sizes: Optional[torch.LongTensor] = None,
+    ):
+        inputs_embeds = self.text_model.model.embed_tokens(input_ids)
+
+        if vision_embeds is not None:
+            # When we generate, we don't want to replace the potential image_token_id that we generated by images
+            # that simply don't exist
+            original_inputs_embeds_shape = inputs_embeds.shape
+            special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(
+                -1
+            )
@@ -1414,19 +1395,33 @@ class Llama4ForConditionalGeneration(nn.Module):
             final_mask_1d = final_mask[..., 0].reshape(-1)
             num_tokens_to_fill = final_mask_1d.sum()
 
-            if num_tokens_to_fill != projected_vision_flat.size(0):
+            if num_tokens_to_fill != vision_embeds.size(0):
                 raise ValueError(
                     f"Mismatch: final_mask wants {num_tokens_to_fill} embeddings, "
-                    f"but multi_modal_projector returned {projected_vision_flat.size(0)}"
+                    f"but multi_modal_projector returned {vision_embeds.size(0)}"
                 )
 
             expanded_mask = final_mask_1d.unsqueeze(-1).expand(
                 -1, inputs_embeds.size(-1)
             )
-            inputs_embeds = inputs_embeds.masked_scatter(
-                expanded_mask, projected_vision_flat
-            )
+            inputs_embeds = inputs_embeds.masked_scatter(expanded_mask, vision_embeds)
             inputs_embeds = inputs_embeds.view(original_inputs_embeds_shape)
+        return inputs_embeds
+
+    def forward(
+        self,
+        inputs_embeds: torch.Tensor,
+        position_ids: Optional[torch.LongTensor] = None,
+        cu_seqlen_prefill: Optional[torch.Tensor] = None,
+        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]] = None,
+        slots: torch.Tensor = None,
+        seqlen: Seqlen = None,
+        hpu_attention_meta: Optional[HPUPagedAttentionMetadata] = None,
+        lm_head_indices: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        adapter_data: Optional[torch.Tensor] = None,
+        **lm_kwargs,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
 
         logits, speculative_logits = self.text_model(
             inputs_embeds,
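The two hunks above split the old monolithic forward into three steps: get_vision_embeds runs the vision tower and the multi-modal projector, get_inputs_embeds embeds the token ids and scatters the projected vision features into the image placeholder positions, and the new forward only runs the text model on the merged embeddings. Below is a minimal, self-contained sketch of the masked_scatter step on plain tensors; the hidden size, the token ids, and the placeholder id are invented for illustration, and the mask construction is simplified compared with the `final_mask` built in the real code.

    import torch

    hidden_size = 8
    image_token_index = 5                        # hypothetical image placeholder id
    input_ids = torch.tensor([1, 5, 5, 2])       # two placeholders in a flat sequence
    inputs_embeds = torch.randn(input_ids.shape[0], hidden_size)
    vision_embeds = torch.randn(2, hidden_size)  # one projected row per placeholder

    # Mark the placeholder positions, as get_inputs_embeds does.
    special_image_mask = (input_ids == image_token_index).unsqueeze(-1)
    final_mask_1d = special_image_mask[..., 0].reshape(-1)
    assert final_mask_1d.sum() == vision_embeds.size(0)

    # Broadcast the 1-D mask over the hidden dimension and scatter row by row.
    expanded_mask = final_mask_1d.unsqueeze(-1).expand(-1, inputs_embeds.size(-1))
    inputs_embeds = inputs_embeds.masked_scatter(expanded_mask, vision_embeds)

masked_scatter fills the True positions of the mask, in order, with the elements of vision_embeds, which is why the size check against final_mask_1d.sum() runs first.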
@@ -285,7 +285,7 @@ def scatter_image_embeds(
         (is_embed.shape[0], embeds.shape[-1]),
         fill_value=torch.nan,
     )
-    placeholders[is_embed] = embeds
+    placeholders[is_embed.to(embeds.device)] = embeds
     return placeholders
 
 
@@ -294,7 +294,7 @@ def gather_image_embeds(
 ) -> Optional[torch.Tensor]:
     if is_embed is None:
         return embeds
-    sel = embeds[is_embed]
+    sel = embeds[is_embed.to(embeds.device)]
     return sel if sel.numel() else None
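Both one-line changes above fix the same issue: `is_embed` is a boolean mask built on the CPU side of the batch, while `embeds` may already live on the accelerator, and indexing a tensor with a mask on a different device fails. Moving the mask with `.to(embeds.device)` keeps the scatter/gather pair device-agnostic. A small sketch of the pattern, with made-up shapes and a device picked at runtime:

    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"
    is_embed = torch.tensor([True, False, True, True])           # mask built on the CPU side
    embeds = torch.randn(int(is_embed.sum()), 8, device=device)  # one row per True entry

    # scatter: place the rows of `embeds` at the True positions, NaN elsewhere.
    placeholders = torch.full(
        (is_embed.shape[0], embeds.shape[-1]), fill_value=torch.nan, device=embeds.device
    )
    placeholders[is_embed.to(embeds.device)] = embeds

    # gather: recover exactly those rows again.
    sel = placeholders[is_embed.to(embeds.device)]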
@@ -553,8 +553,12 @@ class FlashVlmCausalLMBatch(FlashCausalLMBatch):
         batch.image_grid_thw = None
         return batch
 
-    def prepare_for_prefill(self):
-        super().prepare_for_prefill()
+    def prepare_for_prefill(
+        self, max_padded_input_len, max_padded_bs, max_total_tokens
+    ):
+        super().prepare_for_prefill(
+            max_padded_input_len, max_padded_bs, max_total_tokens
+        )
 
         self.has_image_inputs = False
         self.cache_entries_to_free = []
@@ -1003,6 +1007,9 @@ class FlashVlmCausalLM(FlashCausalLM):
                 input_ids.device
             )
             attention_mask = attention_mask.reshape(-1)
+        if self.model.config.model_type == "llama4":
+            attention_mask = (input_ids != 0).long()
+            attention_mask_forward = attention_mask.view(input_lengths.shape[0], -1)
 
         if cu_seqlen_prefill is None and self.max_past() is not None:
             # In decode, not prefill, we're actually overwriting the KV-cache
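For llama4 the runner rebuilds the padding mask directly from the flattened `input_ids` (pad id 0) and reshapes it to `(batch, padded_length)` before passing it on as `attention_mask_forward`. A standalone sketch of that reshaping, with invented token values and lengths:

    import torch

    # Two requests padded to length 5; pad token id 0 (values are made up).
    input_ids = torch.tensor([[7, 9, 4, 0, 0],
                              [3, 8, 6, 2, 5]]).reshape(-1)  # the runner sees a flat tensor
    input_lengths = torch.tensor([3, 5])

    attention_mask = (input_ids != 0).long()  # 1 for real tokens, 0 for padding
    attention_mask_forward = attention_mask.view(input_lengths.shape[0], -1)
    # attention_mask_forward -> [[1, 1, 1, 0, 0],
    #                            [1, 1, 1, 1, 1]]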
@@ -46,6 +46,13 @@ class FlashMllamaCausalLMBatch(FlashVlmCausalLMBatch):
     aspect_ratio_mask: Optional[torch.Tensor] = None
     cross_attention_states: Optional[torch.Tensor] = None
 
+    def prepare_for_prefill(
+        self, max_padded_input_len, max_padded_bs, max_total_tokens
+    ):
+        super(FlashVlmCausalLMBatch, self).prepare_for_prefill(
+            max_padded_input_len, max_padded_bs, max_total_tokens
+        )
+
     @classmethod
     @tracer.start_as_current_span("concatenate")
     def concatenate(cls, batches, padded_total_bs: int = 0):
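`super(FlashVlmCausalLMBatch, self)` starts the method lookup after FlashVlmCausalLMBatch in the MRO, so the Mllama batch reuses FlashCausalLMBatch.prepare_for_prefill directly and skips the VLM-specific override. The toy classes below only mirror the names to illustrate that lookup order; their bodies are not the real implementations.

    class FlashCausalLMBatch:
        def prepare_for_prefill(self, max_padded_input_len, max_padded_bs, max_total_tokens):
            print("base prepare_for_prefill")

    class FlashVlmCausalLMBatch(FlashCausalLMBatch):
        def prepare_for_prefill(self, max_padded_input_len, max_padded_bs, max_total_tokens):
            super().prepare_for_prefill(max_padded_input_len, max_padded_bs, max_total_tokens)
            print("vlm-only image bookkeeping")

    class FlashMllamaCausalLMBatch(FlashVlmCausalLMBatch):
        def prepare_for_prefill(self, max_padded_input_len, max_padded_bs, max_total_tokens):
            # Start the MRO lookup after FlashVlmCausalLMBatch: only the base method runs.
            super(FlashVlmCausalLMBatch, self).prepare_for_prefill(
                max_padded_input_len, max_padded_bs, max_total_tokens
            )

    FlashMllamaCausalLMBatch().prepare_for_prefill(128, 4, 2048)  # prints "base prepare_for_prefill"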