diff --git a/server/text_generation_server/cache.py b/server/text_generation_server/cache.py
index 79fcd3aa5..4504733e5 100644
--- a/server/text_generation_server/cache.py
+++ b/server/text_generation_server/cache.py
@@ -1,3 +1,5 @@
+import torch
+
 from typing import Dict, Optional, TypeVar
 
 from text_generation_server.models.types import Batch
@@ -20,6 +22,8 @@ class Cache:
         batch = self.pop(batch_id)
         if batch is not None:
             del batch
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
 
     def clear(self):
         keys = list(self.cache.keys())
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index bf5f5bbe1..bebd3df56 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -638,6 +638,8 @@ class FlashCausalLMBatch(Batch):
         # Needed to avoid dropping blocks when the batches will go out of scope
         for b in batches:
             b.block_tables = None
+            del b
+        torch.cuda.empty_cache()
 
         return FlashCausalLMBatch(
             batch_id=batches[0].batch_id,
@@ -732,6 +734,7 @@ class FlashCausalLM(Model):
             )
             raise e
         del batch
+        torch.cuda.empty_cache()
 
     def decode(self, generated_ids: Union[torch.Tensor, List[int]]) -> str:
         return self.tokenizer.decode(
@@ -775,16 +778,21 @@ class FlashCausalLM(Model):
             # Allocate blocks to this batch
             CACHE_MANAGER.allocate(batch)
 
-        out = self.forward(
-            batch.input_ids,
-            batch.position_ids,
-            batch.cu_seqlen_prefill,
-            batch.block_tables_tensor,
-            batch.slots[batch.slot_indices],
-            batch.input_lengths_tensor,
-            batch.max_seqlen,
-            batch.prefill_head_indices,
-        )
+        try:
+            out = self.forward(
+                batch.input_ids,
+                batch.position_ids,
+                batch.cu_seqlen_prefill,
+                batch.block_tables_tensor,
+                batch.slots[batch.slot_indices],
+                batch.input_lengths_tensor,
+                batch.max_seqlen,
+                batch.prefill_head_indices,
+            )
+        except Exception as e:
+            del batch
+            torch.cuda.empty_cache()
+            raise e
 
         if prefill:
             next_token_logits = (
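
For context, the whole patch boils down to one pattern: drop the references holding GPU tensors and call torch.cuda.empty_cache() both after a batch is discarded and when the forward pass fails. The sketch below illustrates that pattern in isolation, under stated assumptions; DummyBatch, run_forward, and generate_token_safely are hypothetical stand-ins for illustration only, not names from text-generation-inference.

# Minimal, self-contained sketch of the "free on failure" pattern the patch applies.
# DummyBatch and run_forward are hypothetical stand-ins; they are not TGI APIs.
import torch


class DummyBatch:
    """Stand-in for a batch object holding GPU (or CPU) tensors."""

    def __init__(self, size: int, device: str):
        self.input_ids = torch.zeros(size, dtype=torch.long, device=device)


def run_forward(batch: DummyBatch) -> torch.Tensor:
    """Stand-in for the model's forward pass."""
    return batch.input_ids * 2


def generate_token_safely(batch: DummyBatch) -> torch.Tensor:
    try:
        return run_forward(batch)
    except Exception:
        # Drop this frame's reference to the batch so its tensors can be freed,
        # then let PyTorch's caching allocator return unused cached memory to
        # the device before propagating the error, mirroring the patched
        # generate_token path.
        del batch
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        raise


if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    out = generate_token_safely(DummyBatch(8, device))
    print(out.shape)

One design note: empty_cache() does not free live tensors, it only releases memory that the caching allocator is holding but no longer using, which is why the patch always deletes the batch (or its block tables) first.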