Llava-next: Added flash_attention_recompute option (#220)

Thanaji Rao Thakkalapelli 2024-09-06 13:20:07 -07:00 committed by GitHub
parent 2299b739fe
commit ad7c620f0f

@@ -87,6 +87,7 @@ class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration):
         return_dict: Optional[bool] = None,
         token_idx: Optional[torch.Tensor] = None,
         use_flash_attention: Optional[bool] = False,
+        flash_attention_recompute: Optional[bool] = False,
     ):
         if token_idx is not None:
@@ -109,7 +110,7 @@ class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration):
                 return_dict=return_dict,
                 token_idx=token_idx,
                 use_flash_attention=use_flash_attention,
-                flash_attention_recompute=use_flash_attention,
+                flash_attention_recompute=flash_attention_recompute,
             )
             logits = outputs[0]
@@ -149,6 +150,8 @@ class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration):
             )
         else:
             use_flash_attention = kwargs.get("use_flash_attention", False)
+            flash_attention_recompute = kwargs.get("flash_attention_recompute", False)
             position_ids = kwargs.get("position_ids", None)
             labels = kwargs.get("labels", None)
             if past_key_values is None and pixel_values is not None and input_ids.shape[1] != 1:
@@ -169,7 +172,10 @@ class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration):
                 batch_size, num_patches, num_channels, height, width = pixel_values.shape
                 reshaped_pixel_values = pixel_values.view(batch_size * num_patches, num_channels, height, width)
                 image_features = self.vision_tower(
-                    reshaped_pixel_values, output_hidden_states=True, use_flash_attention=use_flash_attention
+                    reshaped_pixel_values,
+                    output_hidden_states=True,
+                    use_flash_attention=use_flash_attention,
+                    flash_attention_recompute=flash_attention_recompute,
                 )
                 selected_image_feature = image_features.hidden_states[vision_feature_layer]
@@ -283,6 +289,7 @@ class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration):
                     "token_idx": token_idx,
                     "labels": labels,
                     "use_flash_attention": use_flash_attention,
+                    "flash_attention_recompute": flash_attention_recompute,
                 }
             )
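With this change, flash_attention_recompute is read from the generation kwargs independently of use_flash_attention (previously it was hard-wired to the latter) and is threaded through both the language model and the vision tower. Below is a minimal usage sketch, assuming the optimum-habana style flow where extra kwargs passed to generate() are forwarded to prepare_inputs_for_generation; the checkpoint name, image URL, and prompt are illustrative placeholders, and actually running it requires a Gaudi device with the corresponding stack installed.

# Sketch only: passing the new flag as a generation kwarg, which
# prepare_inputs_for_generation picks up via
# kwargs.get("flash_attention_recompute", False), as shown in the diff above.
# Checkpoint, image URL, and prompt are placeholders, not from this commit.
import requests
import torch
from PIL import Image
from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration

model_id = "llava-hf/llava-v1.6-mistral-7b-hf"
processor = LlavaNextProcessor.from_pretrained(model_id)
model = LlavaNextForConditionalGeneration.from_pretrained(model_id, torch_dtype=torch.bfloat16)

image = Image.open(requests.get("https://llava-vl.github.io/static/images/view.jpg", stream=True).raw)
prompt = "[INST] <image>\nWhat is shown in this image? [/INST]"
inputs = processor(images=image, text=prompt, return_tensors="pt")

# Both flags default to False. After this commit, recompute mode can be
# toggled independently of flash attention itself.
output = model.generate(
    **inputs,
    max_new_tokens=100,
    use_flash_attention=True,
    flash_attention_recompute=True,
)
print(processor.decode(output[0], skip_special_tokens=True))

On Gaudi, flash_attention_recompute enables the recompute mode of the fused SDPA kernel, trading some recomputation for lower attention memory; exposing it as its own flag lets callers enable flash attention without also forcing recompute.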