From fcbc6876c0ced3557d6d93f87824fb9e2c7e0d89 Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Tue, 25 Jun 2024 12:24:45 +0000
Subject: [PATCH] No need for cache_manager anymore.

---
 .../models/cache_manager.py | 157 ------------------
 1 file changed, 157 deletions(-)
 delete mode 100644 server/text_generation_server/models/cache_manager.py

diff --git a/server/text_generation_server/models/cache_manager.py b/server/text_generation_server/models/cache_manager.py
deleted file mode 100644
index 518f9abb..00000000
--- a/server/text_generation_server/models/cache_manager.py
+++ /dev/null
@@ -1,157 +0,0 @@
-import math
-import torch
-
-from typing import Optional, List, Tuple
-from text_generation_server.utils.import_utils import SYSTEM
-from text_generation_server.models.globals import FLASH_DECODING, BLOCK_SIZE
-
-# Will be set in warmup
-CACHE_MANAGER: Optional["CacheManager"] = None
-
-
-class CacheManager:
-    def __init__(
-        self,
-        num_blocks: int,
-        num_layers: int,
-        num_heads: int,
-        head_size: int,
-        repeat_slots: bool,
-        dtype: torch.dtype,
-        device: torch.device,
-    ):
-        self.block_size = BLOCK_SIZE
-        self.num_blocks = num_blocks
-        self.repeat_slots = repeat_slots
-
-        element_size = torch.tensor([], dtype=dtype).element_size()
-        if SYSTEM == "xpu":
-            x = 1
-        else:
-            x = self.block_size // element_size
-
-        if FLASH_DECODING:
-            self.kv_cache = [
-                (
-                    torch.empty(
-                        (num_blocks, self.block_size, num_heads, head_size),
-                        dtype=dtype,
-                        device=device,
-                    ),
-                    torch.empty(
-                        (num_blocks, self.block_size, num_heads, head_size),
-                        dtype=dtype,
-                        device=device,
-                    ),
-                )
-                for _ in range(num_layers)
-            ]
-        else:
-            self.kv_cache = [
-                (
-                    torch.empty(
-                        (num_blocks, num_heads, head_size // x, self.block_size, x),
-                        dtype=dtype,
-                        device=device,
-                    ),
-                    torch.empty(
-                        (num_blocks, num_heads, head_size, self.block_size),
-                        dtype=dtype,
-                        device=device,
-                    ),
-                )
-                for _ in range(num_layers)
-            ]
-        self.free_block_mask = torch.ones(num_blocks, dtype=torch.int32, device="cpu")
-        self.slots = torch.arange(
-            0, num_blocks * self.block_size, dtype=torch.int64
-        ).view(num_blocks, self.block_size)
-
-    def allocate(
-        self,
-        needed_blocks_slots: List[Tuple[int, int]],
-        blocks: int,
-        max_blocks: int,
-        device: torch.device,
-    ):
-        # Get free blocks indices by finding values in mask that are not set to 0
-        free_block_indices = self.free_block_mask.nonzero()
-        if blocks > len(free_block_indices):
-            raise RuntimeError(
-                f"Out of available cache blocks: asked {blocks}, only {len(free_block_indices)} free blocks"
-            )
-
-        # Slice by the number of required blocks
-        block_indices = free_block_indices[:blocks]
-        block_indices = block_indices.flatten()
-
-        # Padded block tables
-        block_tables_tensor = torch.zeros(
-            (len(needed_blocks_slots), max_blocks), dtype=torch.int32
-        )
-
-        # Allocate paged attention blocks
-        cumulative_blocks = 0
-        slots = []
-        block_tables = []
-        for i, (needed_blocks, needed_slots) in enumerate(needed_blocks_slots):
-            # Get allocated blocks for this sequence
-            allocated_blocks = block_indices[
-                cumulative_blocks : cumulative_blocks + needed_blocks
-            ]
-            # Get slots for the allocated blocks
-            all_slots = self.slots[allocated_blocks].flatten()
-
-            # Repeat slots in the case of context sliding window
-            if needed_slots > len(all_slots) and self.repeat_slots:
-                repeats = math.ceil(needed_slots / len(all_slots))
-                all_slots = all_slots.repeat(repeats)
-
-            allocated_slots = all_slots[:needed_slots]
-
-            slots.append(allocated_slots)
-            block_tables.append(allocated_blocks.tolist())
-            block_tables_tensor[i, :needed_blocks] = allocated_blocks
-            cumulative_blocks += needed_blocks
-
-        block_tables = block_tables
-        block_tables_tensor = block_tables_tensor.to(device)
-        slots = torch.concat(slots).to(device)
-
-        # Allocate the required number of blocks by setting the mask to 0
-        self.free_block_mask[block_indices] = 0
-
-        return block_tables, block_tables_tensor, slots
-
-    def free(self, block_indices: Optional[List[int]]):
-        if block_indices is not None and block_indices:
-            # Reset mask
-            self.free_block_mask[block_indices] = 1
-
-
-def set_cache_manager(
-    num_blocks: int,
-    num_layers: int,
-    num_heads: int,
-    head_size: int,
-    repeat_slots: bool,
-    dtype: torch.dtype,
-    device: torch.device,
-) -> CacheManager:
-    global CACHE_MANAGER
-    if CACHE_MANAGER is not None:
-        del CACHE_MANAGER
-        torch.cuda.empty_cache()
-
-    CACHE_MANAGER = CacheManager(
-        num_blocks, num_layers, num_heads, head_size, repeat_slots, dtype, device
-    )
-    return CACHE_MANAGER
-
-
-def get_cache_manager() -> CacheManager:
-    global CACHE_MANAGER
-    if CACHE_MANAGER is None:
-        raise RuntimeError("cache manager was not initialized")
-
-    return CACHE_MANAGER
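
For context, a minimal sketch of how the removed module's public surface (set_cache_manager, get_cache_manager, CacheManager.allocate, CacheManager.free) was typically driven, assuming placeholder block/head counts and a CUDA device; after this patch the import path no longer exists, so this only illustrates what callers migrate away from.

    # Sketch only: sizes below are hypothetical placeholders, not values from this PR.
    import torch

    from text_generation_server.models.cache_manager import (
        get_cache_manager,
        set_cache_manager,
    )

    # Warmup: size the paged KV cache once and store it in the module-level global.
    set_cache_manager(
        num_blocks=128,
        num_layers=32,
        num_heads=32,
        head_size=128,
        repeat_slots=False,  # True when a sliding context window repeats slots
        dtype=torch.float16,
        device=torch.device("cuda"),
    )

    cache = get_cache_manager()

    # Per batch: reserve (blocks, slots) for each sequence in the batch...
    needed_blocks_slots = [(2, 20), (1, 7)]  # two hypothetical sequences
    block_tables, block_tables_tensor, slots = cache.allocate(
        needed_blocks_slots,
        blocks=3,       # total blocks requested across the batch
        max_blocks=2,   # widest block table, used to pad block_tables_tensor
        device=torch.device("cuda"),
    )

    # ...and hand the blocks back to the free mask once the sequences finish.
    cache.free([block for table in block_tables for block in table])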