diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
index a4ad8f59..d45cd6ce 100644
--- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
@@ -111,7 +111,7 @@ class MistralAttention(torch.nn.Module):
         )
         self.num_heads = config.num_attention_heads
         self.hidden_size = config.hidden_size
-        if hasattr(config, "head_dim"):
+        if getattr(config, "head_dim", None) is not None:
             self.head_size = config.head_dim
         else:
             self.head_size = self.hidden_size // self.num_heads
diff --git a/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py b/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py
index 086c05e7..5bd2292e 100644
--- a/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py
+++ b/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py
@@ -1050,8 +1050,6 @@ class FlashVlmCausalLM(FlashCausalLM):
             attention_mask=attention_mask_forward,
             **kwargs,
         )
-        if batch.prefill_cache_indices is not None:
-            batch.prefill_cache_indices = None
        batch.image_grid_thw = None
        batch.free_encoder_cache()
        return logits, speculative_logits
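
Context on the first hunk: `hasattr` only tests that the attribute is defined, so it returns `True` even when a config explicitly sets `head_dim = None`, and `head_size` would silently become `None` instead of falling back to `hidden_size // num_heads`. A minimal, self-contained sketch of the difference (`MiniConfig` here is a hypothetical stand-in for the actual Hugging Face Mistral config, not the real class):

```python
# Sketch only: MiniConfig is a hypothetical stand-in for the HF config,
# where head_dim is a declared field that may legitimately be None.
from dataclasses import dataclass
from typing import Optional


@dataclass
class MiniConfig:
    hidden_size: int = 4096
    num_attention_heads: int = 32
    head_dim: Optional[int] = None  # attribute exists, value is unset


config = MiniConfig()

# Old check: hasattr() is True because the attribute is present,
# so head_size silently becomes None.
if hasattr(config, "head_dim"):
    old_head_size = config.head_dim
else:
    old_head_size = config.hidden_size // config.num_attention_heads
assert old_head_size is None  # the bug

# New check: also rejects an explicit None and takes the fallback.
if getattr(config, "head_dim", None) is not None:
    new_head_size = config.head_dim
else:
    new_head_size = config.hidden_size // config.num_attention_heads
assert new_head_size == 4096 // 32  # 128
```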