diff --git a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py
index 5c7b8bc01..8a5668a5d 100644
--- a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py
+++ b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py
@@ -1785,7 +1785,7 @@ class FlashCausalLM(Model):
 
         kwargs = {}
         if htorch.utils.internal.is_lazy():
-            kwargs["bypass_hpu_graphs"] = False
+            kwargs["bypass_hpu_graphs"] = batch.prefilling
 
         logits, speculative_logits = self.model.forward(
             input_ids=input_ids,
diff --git a/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py b/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py
index 6a0661851..c1ea36f29 100644
--- a/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py
+++ b/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py
@@ -455,7 +455,7 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
 
         kwargs = {}
         if htorch.utils.internal.is_lazy():
-            kwargs["bypass_hpu_graphs"] = False
+            kwargs["bypass_hpu_graphs"] = batch.prefilling
         if batch.prefill_cache_indices is not None:
             slots_pad = torch.zeros_like(input_ids)
             slots_pad[batch.prefill_cache_indices] = slots
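
The change makes HPU graph bypass conditional: graphs are skipped while `batch.prefilling` is true (prefill shapes vary per request) and used for decode steps, instead of never being bypassed. Below is a minimal, self-contained sketch of that pattern, not the TGI code itself; it assumes the module has been wrapped with `htorch.hpu.wrap_in_hpu_graph`, whose wrapped forward accepts the `bypass_hpu_graphs` keyword, and `ToyModel` / `prefilling` are illustrative names that do not appear in the diff.

import torch
import habana_frameworks.torch as htorch


class ToyModel(torch.nn.Module):
    """Illustrative stand-in for the real language model."""

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(16, 16)

    def forward(self, x):
        return self.linear(x)


model = ToyModel().to("hpu")
if htorch.utils.internal.is_lazy():
    # Wrapping adds HPU graph capture/replay to the module's forward.
    model = htorch.hpu.wrap_in_hpu_graph(model)


def run_step(model, x, prefilling: bool):
    # Mirror the diff: only pass the kwarg in lazy mode, and let the
    # prefill flag decide whether graph capture/replay is bypassed.
    kwargs = {}
    if htorch.utils.internal.is_lazy():
        kwargs["bypass_hpu_graphs"] = prefilling
    return model(x, **kwargs)


x = torch.randn(4, 16, device="hpu")
y_prefill = run_step(model, x, prefilling=True)   # eager lazy-mode execution
y_decode = run_step(model, x, prefilling=False)   # captured/replayed HPU graph

The design choice being sketched: decode iterations have stable shapes and benefit from graph replay, while prefill batches change shape frequently, so replaying (or repeatedly recapturing) graphs there costs more than it saves.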