From b1831d5f97d239c099a4d1bd9e6547cb2063884c Mon Sep 17 00:00:00 2001
From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com>
Date: Fri, 30 Jun 2023 17:43:40 +0200
Subject: [PATCH] double free

---
 README.md                                                | 4 ++--
 server/text_generation_server/models/flash_causal_lm.py  | 4 +++-
 2 files changed, 5 insertions(+), 3 deletions(-)

diff --git a/README.md b/README.md
index 8c8d9773..b74d2617 100644
--- a/README.md
+++ b/README.md
@@ -43,8 +43,8 @@ to power LLMs api-inference widgets.
 - Tensor Parallelism for faster inference on multiple GPUs
 - Token streaming using Server-Sent Events (SSE)
 - [Continuous batching of incoming requests](https://github.com/huggingface/text-generation-inference/tree/main/router) for increased total throughput
-- Optimized transformers code for inference using [flash-attention](https://github.com/HazyResearch/flash-attention) on the most popular architectures
-- Quantization with [bitsandbytes](https://github.com/TimDettmers/bitsandbytes)
+- Optimized transformers code for inference using [flash-attention](https://github.com/HazyResearch/flash-attention) and [Paged Attention](https://github.com/vllm-project/vllm) on the most popular architectures
+- Quantization with [bitsandbytes](https://github.com/TimDettmers/bitsandbytes) and [GPT-Q](https://arxiv.org/abs/2210.17323)
 - [Safetensors](https://github.com/huggingface/safetensors) weight loading
 - Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
 - Logits warper (temperature scaling, top-p, top-k, repetition penalty, more details see [transformers.LogitsProcessor](https://huggingface.co/docs/transformers/internal/generation_utils#transformers.LogitsProcessor))
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index acf17695..94b14f85 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -481,6 +481,8 @@ class FlashCausalLMBatch(Batch):
             block_indices_to_free.extend(self.block_tables[i])
         # Free blocks
         CACHE_MANAGER.free(block_indices_to_free)
+        # Needed to avoid dropping blocks when the batches will go out of scope
+        self.block_tables = None
 
         # Index into tensors
         input_ids = self.input_ids[indices]
@@ -675,7 +677,7 @@ class FlashCausalLMBatch(Batch):
         )
 
     def __del__(self):
-        if self.block_tables is not None:
+        if self.block_tables is not None and self.block_tables:
             global CACHE_MANAGER
             # Free blocks
             CACHE_MANAGER.free(list(itertools.chain.from_iterable(self.block_tables)))
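
Below is a minimal, self-contained Python sketch (not the actual TGI code) of the double-free pattern this patch guards against. The CacheManager, Batch, and release names are hypothetical stand-ins: cache blocks that are freed explicitly (as in FlashCausalLMBatch.filter above) must not be freed again by __del__ when the old batch object is garbage-collected, which is why the patch sets self.block_tables = None right after the explicit free.

    import itertools


    class CacheManager:
        """Hypothetical manager that hands out KV-cache block ids."""

        def __init__(self, num_blocks: int):
            self.free_blocks = set(range(num_blocks))

        def allocate(self, n: int) -> list:
            # Take n arbitrary free blocks.
            return [self.free_blocks.pop() for _ in range(n)]

        def free(self, blocks: list):
            for b in blocks:
                # Returning a block that is already free corrupts cache reuse.
                assert b not in self.free_blocks, f"double free of block {b}"
                self.free_blocks.add(b)


    CACHE_MANAGER = CacheManager(16)


    class Batch:
        """Hypothetical batch owning one block table per request."""

        def __init__(self, block_tables: list):
            self.block_tables = block_tables

        def release(self):
            # Explicit free, analogous to freeing blocks in filter().
            CACHE_MANAGER.free(list(itertools.chain.from_iterable(self.block_tables)))
            # The fix: drop the reference so __del__ cannot free the same
            # blocks again when this object goes out of scope.
            self.block_tables = None

        def __del__(self):
            # Safety net, mirroring the patched guard: only free blocks that
            # were never released explicitly.
            if self.block_tables is not None and self.block_tables:
                CACHE_MANAGER.free(
                    list(itertools.chain.from_iterable(self.block_tables))
                )


    batch = Batch([CACHE_MANAGER.allocate(2), CACHE_MANAGER.allocate(3)])
    batch.release()
    del batch  # __del__ is a no-op here, so the blocks are not freed twice

The extra "and self.block_tables" clause added to __del__ also skips the free when the block table is an empty list, not only when it has been set to None.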