diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
index a4ad8f59..d45cd6ce 100644
--- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
@@ -111,7 +111,7 @@ class MistralAttention(torch.nn.Module):
         )
         self.num_heads = config.num_attention_heads
         self.hidden_size = config.hidden_size
-        if hasattr(config, "head_dim"):
+        if getattr(config, "head_dim", None) is not None:
             self.head_size = config.head_dim
         else:
             self.head_size = self.hidden_size // self.num_heads
diff --git a/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py b/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py
index 086c05e7..5bd2292e 100644
--- a/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py
+++ b/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py
@@ -1050,8 +1050,6 @@ class FlashVlmCausalLM(FlashCausalLM):
             attention_mask=attention_mask_forward,
             **kwargs,
         )
-        if batch.prefill_cache_indices is not None:
-            batch.prefill_cache_indices = None
         batch.image_grid_thw = None
         batch.free_encoder_cache()
         return logits, speculative_logits
diff --git a/backends/gaudi/server/text_generation_server/utils/debug.py b/backends/gaudi/server/text_generation_server/utils/debug.py
index 8bbcad6a..690da54e 100644
--- a/backends/gaudi/server/text_generation_server/utils/debug.py
+++ b/backends/gaudi/server/text_generation_server/utils/debug.py
@@ -4,8 +4,8 @@
 import os
 import glob
 import time
-from optimum.habana.utils import to_gb_rounded
 import habana_frameworks.torch as htorch
+import numpy as np
 
 START_TS = None
 DBG_TRACE_FILENAME = os.environ.get("DBG_TRACE_FILENAME")
@@ -14,6 +14,19 @@ if "GRAPH_VISUALIZATION" in os.environ:
         os.remove(f)
 
 
+def to_gb_rounded(mem: float) -> float:
+    """
+    Rounds and converts to GB.
+
+    Args:
+        mem (float): memory in bytes
+
+    Returns:
+        float: memory in GB rounded to the second decimal
+    """
+    return np.round(mem / 1024**3, 2)
+
+
 def count_hpu_graphs():
     return len(glob.glob(".graph_dumps/*PreGraph*"))
 
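
Note on the flash_mistral_modeling.py hunk: hasattr(config, "head_dim") is true even when the config declares head_dim = None, which would propagate None into self.head_size; getattr(config, "head_dim", None) is not None falls back to the computed default instead. A minimal sketch of the difference, using a hypothetical Config stand-in rather than the real MistralConfig:

    # Hypothetical stand-in for a model config that declares head_dim = None.
    class Config:
        hidden_size = 4096
        num_attention_heads = 32
        head_dim = None  # attribute exists, but carries no value

    config = Config()

    # Old check: passes because the attribute exists, so head_size would be None.
    assert hasattr(config, "head_dim")

    # New check: falls through to the computed default instead.
    if getattr(config, "head_dim", None) is not None:
        head_size = config.head_dim
    else:
        head_size = config.hidden_size // config.num_attention_heads  # 128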
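
Note on the debug.py hunks: inlining to_gb_rounded drops the import from optimum.habana.utils; the local numpy-based copy is intended as a drop-in replacement that takes a byte count and returns gigabytes rounded to two decimals. A quick sanity check of that behavior (illustrative values, not from the source):

    assert to_gb_rounded(3 * 1024**3) == 3.0
    assert to_gb_rounded(1.5 * 1024**3 + 7) == 1.5  # rounded to the second decimal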