diff --git a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py
index 27e1c6724..3a0dc15e0 100644
--- a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py
+++ b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py
@@ -1398,7 +1398,7 @@ class FlashCausalLM(Model):
         self.kv_cache_dtype = dtype if kv_cache_dtype is None else kv_cache_dtype
 
         if htorch.utils.internal.is_lazy():
-            htorch.hpu.wrap_in_hpu_graph(model, disable_tensor_cache=True)
+            htorch.hpu.wrap_in_hpu_graph(model, disable_tensor_cache=False)
         environment.set_model_config(self.config)
         self.use_contiguous_pa = (
             os.environ.get("VLLM_CONTIGUOUS_PA", "true").lower() == "true"
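
For context, a minimal standalone sketch of the wrapping pattern this diff touches (not part of the change itself). It assumes habana_frameworks.torch is installed and an HPU device is available; the Linear module is a hypothetical stand-in for the TGI model object.

import torch
import habana_frameworks.torch as htorch

# Hypothetical placeholder model moved to the HPU device.
model = torch.nn.Linear(16, 16).to("hpu")

if htorch.utils.internal.is_lazy():
    # disable_tensor_cache controls HPU-graph tensor caching; the diff switches
    # it from True to False, i.e. caching is no longer disabled when the model
    # is wrapped in an HPU graph.
    model = htorch.hpu.wrap_in_hpu_graph(model, disable_tensor_cache=False)

# The wrapped module is called like any other torch module.
out = model(torch.randn(4, 16).to("hpu"))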