diff --git a/launcher/src/main.rs b/launcher/src/main.rs
index 942f7459..8497f807 100644
--- a/launcher/src/main.rs
+++ b/launcher/src/main.rs
@@ -173,8 +173,9 @@ struct Args {
     /// for end users.
     #[clap(default_value = "20", long, env)]
     max_waiting_tokens: usize,

-    #[clap(default_value = "3000", long, short, env)]
+    /// The port to listen on.
+    #[clap(default_value = "3000", long, short, env)]
     port: u16,

     /// The name of the socket for gRPC communication between the webserver
diff --git a/router/client/src/client.rs b/router/client/src/client.rs
index c5396cc4..b5e0ccc0 100644
--- a/router/client/src/client.rs
+++ b/router/client/src/client.rs
@@ -113,7 +113,7 @@ impl Client {
             requests.push(Request {
                 id: 0,
                 // We truncate the input on the server side to be sure that it has the correct size
-                inputs: "test".to_string().repeat(max_input_length as usize),
+                inputs: "_test ".to_string().repeat(max_input_length as usize),
                 truncate: min(max_input_length, max_prefill_tokens - n_tokens),
                 // Set sampling parameters to also take these ops into account in the max memory
                 parameters: Some(NextTokenChooserParameters {
diff --git a/server/Makefile-vllm b/server/Makefile-vllm
index b9725ba3..af750733 100644
--- a/server/Makefile-vllm
+++ b/server/Makefile-vllm
@@ -5,7 +5,7 @@ vllm:
 	git clone https://github.com/OlivierDehaene/vllm.git

 build-vllm: vllm
-	cd vllm && git fetch && git checkout $(flash_att_commit)
+	cd vllm && git fetch && git checkout $(vllm_commit)
 	cd vllm && python setup.py build

 install-vllm: build-vllm
diff --git a/server/text_generation_server/cache.py b/server/text_generation_server/cache.py
index fc4c0d3a..79fcd3aa 100644
--- a/server/text_generation_server/cache.py
+++ b/server/text_generation_server/cache.py
@@ -19,7 +19,6 @@ class Cache:
     def delete(self, batch_id: int):
         batch = self.pop(batch_id)
         if batch is not None:
-            batch.free()
             del batch

     def clear(self):
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index e4fe4517..e932b5f1 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -63,24 +63,49 @@ class CacheManager:
             0, num_blocks * self.block_size, dtype=torch.int32
         ).view(num_blocks, self.block_size)

-    def allocate(self, num_blocks: int) -> Tuple[torch.Tensor, torch.Tensor]:
+    def allocate(self, batch: "FlashCausalLMBatch"):
         # Get free blocks indices by finding values in mask that are not set to 0
         free_block_indices = self.free_block_mask.nonzero()
         assert (
-            len(free_block_indices) >= num_blocks
-        ), f"Out of available cache blocks: asked {num_blocks}, only {len(free_block_indices)} free blocks"
+            len(free_block_indices) >= batch.blocks
+        ), f"Out of available cache blocks: asked {batch.blocks}, only {len(free_block_indices)} free blocks"
+
+        # Slice by the number of required blocks
+        block_indices = free_block_indices[: batch.blocks]
+        block_indices = block_indices.flatten()
+
+        # Padded block tables
+        block_tables_tensor = torch.zeros(
+            (len(batch), batch.max_blocks), dtype=torch.int32
+        )
+
+        # Allocate paged attention blocks
+        cumulative_blocks = 0
+        slots = []
+        block_tables = []
+        for i, (needed_blocks, needed_slots) in enumerate(batch.needed_blocks_slots):
+            # Get allocated blocks for this sequence
+            allocated_blocks = block_indices[
+                cumulative_blocks : cumulative_blocks + needed_blocks
+            ]
+            # Get slots for the allocated blocks
+            allocated_slots = self.slots[allocated_blocks].flatten()[:needed_slots]
+
+            slots.append(allocated_slots)
+            block_tables.append(allocated_blocks.tolist())
+            block_tables_tensor[i, :needed_blocks] = allocated_blocks
+            cumulative_blocks += needed_blocks
+
+        batch.needed_blocks_slots = None
+        batch.block_tables = block_tables
+        batch.block_tables_tensor = block_tables_tensor.to(batch.input_ids.device)
+        batch.slots = torch.concat(slots).to(batch.input_ids.device)

         # Allocate the required number of blocks by setting the mask to 0
-        block_indices = free_block_indices[:num_blocks]
         self.free_block_mask[block_indices] = 0

-        # Get slots for the allocated blocks
-        slots = self.slots[block_indices].flatten()
-
-        return block_indices.flatten(), slots
-
     def free(self, block_indices: Optional[List[int]]):
-        if block_indices is not None:
+        if block_indices is not None and block_indices:
             # Reset mask
             self.free_block_mask[block_indices] = 1

@@ -448,12 +473,14 @@ class FlashCausalLMBatch(Batch):
             max_blocks = max(max_blocks, len(request_block_table))

         global CACHE_MANAGER
+        block_indices_to_free = []
         # Iterate on all requests
         for i, r in enumerate(self.requests):
             # Filter requests that are not part of the new batch
             if r.id not in requests_idx_mapping.keys():
-                # Free blocks
-                CACHE_MANAGER.free(self.block_tables[i])
+                block_indices_to_free.extend(self.block_tables[i])
+        # Free blocks
+        CACHE_MANAGER.free(block_indices_to_free)

         # Index into tensors
         input_ids = self.input_ids[indices]
@@ -643,7 +670,7 @@ class FlashCausalLMBatch(Batch):
             max_blocks=max_blocks,
         )

-    def free(self):
+    def __del__(self):
         if self.block_tables is not None:
             global CACHE_MANAGER
             # Free blocks
@@ -703,10 +730,10 @@ class FlashCausalLM(Model):
             logger.exception(
                 f"Not enough memory to handle {max_total_tokens} total tokens with {len(batch.input_ids)} "
                 f"prefill tokens. "
-                f"You need to decrease `--max-batch-total-tokens` and `--max-batch-prefill-tokens`"
+                f"You need to decrease `--max-batch-total-tokens` or `--max-batch-prefill-tokens`"
             )
             raise e
-        batch.free()
+        del batch

     def decode(self, generated_ids: Union[torch.Tensor, List[int]]) -> str:
         return self.tokenizer.decode(
@@ -749,33 +776,8 @@ class FlashCausalLM(Model):
         prefill_logprobs = batch.prefill_next_token_indices is not None

         if batch.needed_blocks_slots:
-            # Padded block tables
-            block_tables_tensor = torch.zeros(
-                (len(batch), batch.max_blocks), dtype=torch.int32
-            )
-
-            # Allocate paged attention blocks
-            slots = []
-            block_tables = []
-            try:
-                for i, (needed_blocks, needed_slots) in enumerate(
-                    batch.needed_blocks_slots
-                ):
-                    allocated_blocks, allocated_slots = CACHE_MANAGER.allocate(
-                        needed_blocks
-                    )
-                    slots.append(allocated_slots[:needed_slots])
-                    block_tables.append(allocated_blocks.tolist())
-                    block_tables_tensor[i, :needed_blocks] = allocated_blocks
-            except Exception as e:
-                for blocks in block_tables:
-                    CACHE_MANAGER.free(blocks)
-                raise e
-
-            batch.needed_blocks_slots = None
-            batch.block_tables = block_tables
-            batch.block_tables_tensor = block_tables_tensor.to(self.device)
-            batch.slots = torch.concat(slots).to(self.device)
+            # Allocate blocks to this batch
+            CACHE_MANAGER.allocate(batch)

         out = self.forward(
             batch.input_ids,
@@ -990,7 +992,7 @@ class FlashCausalLM(Model):
             batch.all_input_ids[i] = all_input_ids

         if stopped:
-            batch.free()
+            del batch
             # No need to return a batch if we know that all requests stopped
             return generations, None

diff --git a/server/text_generation_server/models/types.py b/server/text_generation_server/models/types.py
index c35e15d3..28ca8147 100644
--- a/server/text_generation_server/models/types.py
+++ b/server/text_generation_server/models/types.py
@@ -35,9 +35,6 @@ class Batch(ABC):
     def concatenate(cls, batches: List["Batch"]) -> "Batch":
         raise NotImplementedError

-    def free(self):
-        pass
-
     @abstractmethod
     def __len__(self):
         raise NotImplementedError
diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py
index 378ac841..6cc5beeb 100644
--- a/server/text_generation_server/server.py
+++ b/server/text_generation_server/server.py
@@ -65,12 +65,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
             request.batch, self.model.tokenizer, self.model.dtype, self.model.device
         )

-        try:
-            generations, next_batch = self.model.generate_token(batch)
-        except Exception as e:
-            batch.free()
-            raise e
-
+        generations, next_batch = self.model.generate_token(batch)
         self.cache.set(next_batch)

         return generate_pb2.PrefillResponse(
@@ -93,20 +88,11 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
             raise ValueError("All batches are empty")

         if len(batches) > 1:
-            try:
-                batch = self.model.batch_type.concatenate(batches)
-            except Exception as e:
-                [batch.free() for batch in batches]
-                raise e
+            batch = self.model.batch_type.concatenate(batches)
         else:
             batch = batches[0]

-        try:
-            generations, next_batch = self.model.generate_token(batch)
-        except Exception as e:
-            batch.free()
-            raise e
-
+        generations, next_batch = self.model.generate_token(batch)
         self.cache.set(next_batch)

         return generate_pb2.DecodeResponse(