diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 4e5804f51..d034d4721 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -991,6 +991,7 @@ class FlashCausalLM(Model): if stopped: del batch + torch.cuda.empty_cache() # No need to return a batch if we know that all requests stopped return generations, None