Mirror of https://github.com/huggingface/text-generation-inference.git
Llava-next: Added flash_attention_recompute option (#220)

commit ad7c620f0f (parent 2299b739fe)
@@ -87,6 +87,7 @@ class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration):
         return_dict: Optional[bool] = None,
         token_idx: Optional[torch.Tensor] = None,
         use_flash_attention: Optional[bool] = False,
+        flash_attention_recompute: Optional[bool] = False,
     ):

         if token_idx is not None:
@@ -109,7 +110,7 @@ class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration):
                 return_dict=return_dict,
                 token_idx=token_idx,
                 use_flash_attention=use_flash_attention,
-                flash_attention_recompute=use_flash_attention,
+                flash_attention_recompute=flash_attention_recompute,
             )

             logits = outputs[0]
@@ -149,6 +150,8 @@ class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration):
             )
         else:
             use_flash_attention = kwargs.get("use_flash_attention", False)
+            flash_attention_recompute = kwargs.get("flash_attention_recompute", False)
+
             position_ids = kwargs.get("position_ids", None)
             labels = kwargs.get("labels", None)
             if past_key_values is None and pixel_values is not None and input_ids.shape[1] != 1:
@@ -169,7 +172,10 @@ class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration):
                 batch_size, num_patches, num_channels, height, width = pixel_values.shape
                 reshaped_pixel_values = pixel_values.view(batch_size * num_patches, num_channels, height, width)
                 image_features = self.vision_tower(
-                    reshaped_pixel_values, output_hidden_states=True, use_flash_attention=use_flash_attention
+                    reshaped_pixel_values,
+                    output_hidden_states=True,
+                    use_flash_attention=use_flash_attention,
+                    flash_attention_recompute=flash_attention_recompute,
                 )

                 selected_image_feature = image_features.hidden_states[vision_feature_layer]
@@ -283,6 +289,7 @@ class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration):
                 "token_idx": token_idx,
                 "labels": labels,
                 "use_flash_attention": use_flash_attention,
+                "flash_attention_recompute": flash_attention_recompute,
             }
         )

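For reference, a hypothetical call site, not part of the commit: `model` is assumed to be an instance of the patched LlavaNextForConditionalGeneration shown above, and `inputs` an already-processed image+prompt batch. Extra kwargs passed to generate() reach prepare_inputs_for_generation, where the new flag is read via kwargs.get("flash_attention_recompute", False) and then forwarded to both the vision tower and the language model:

    # Usage sketch; `model` and `inputs` are assumptions, not from the diff.
    output_ids = model.generate(
        **inputs,
        max_new_tokens=64,
        use_flash_attention=True,        # pre-existing Habana flash attention flag
        flash_attention_recompute=True,  # new option introduced by this commit
    )

Before this change, the recompute flag could not be controlled independently: as the second hunk shows, it was mistakenly bound to use_flash_attention when calling the language model.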