diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py index ee726faf..26be1875 100644 --- a/server/text_generation_server/models/causal_lm.py +++ b/server/text_generation_server/models/causal_lm.py @@ -632,7 +632,7 @@ class CausalLM(Model): model = model.eval().to(device) #wrap in hpu_graph only if self.enable_hpu_graph is set if self.enable_hpu_graph: - model = wrap_in_hpu_graph(model) + model = wrap_in_hpu_graph(model, disable_tensor_cache=True) if model.config.model_type in MODELS_OPTIMIZED_WITH_STATIC_SHAPES: self.is_optimized_for_gaudi = True