diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index cb23bfcd..acf17695 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -640,6 +640,10 @@ class FlashCausalLMBatch(Batch): device=batches[0].next_token_chooser.device, ) + # Needed to avoid dropping blocks when the batches will go out of scope + for b in batches: + b.block_tables = None + return FlashCausalLMBatch( batch_id=batches[0].batch_id, requests=requests,