From ddfc02f2a4306d7543d1b1069776ad3ae2e9be71 Mon Sep 17 00:00:00 2001
From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com>
Date: Thu, 29 Jun 2023 15:50:44 +0200
Subject: [PATCH] add warmup

---
 launcher/src/main.rs                           |  14 +-
 proto/generate.proto                           |  12 ++
 router/client/src/client.rs                    |  58 ++++++
 router/client/src/sharded_client.rs            |  21 ++
 router/src/infer.rs                            |   4 +-
 router/src/main.rs                             |  19 +-
 server/text_generation_server/cache.py         |   5 +-
 .../models/causal_lm.py                        |   2 +-
 .../models/flash_causal_lm.py                  | 183 ++++++++++++------
 .../models/flash_llama.py                      |   2 +-
 server/text_generation_server/models/model.py  |   6 +
 .../models/seq2seq_lm.py                       |   2 +-
 server/text_generation_server/models/types.py  |   2 +-
 server/text_generation_server/server.py        |  27 ++-
 14 files changed, 272 insertions(+), 85 deletions(-)

diff --git a/launcher/src/main.rs b/launcher/src/main.rs
index 0684de0f..2deb0e0c 100644
--- a/launcher/src/main.rs
+++ b/launcher/src/main.rs
@@ -128,7 +128,10 @@ struct Args {
     #[clap(default_value = "1.2", long, env)]
     waiting_served_ratio: f32,
 
-    #[clap(default_value = "32000", long, env)]
+    /// Limits the number of tokens for the prefill operation.
+    /// Since this operation takes the most memory and is compute bound, it is worth
+    /// limiting the number of requests that can be sent.
+    #[clap(default_value = "4096", long, env)]
     max_batch_prefill_tokens: u32,
 
     /// **IMPORTANT** This is one critical control to allow maximum usage
@@ -143,13 +146,6 @@ struct Args {
     /// For `max_batch_total_tokens=1000`, you could fit `10` queries of `total_tokens=100`
     /// or a single query of `1000` tokens.
     ///
-    /// So you don't have to control that finely
-    /// `max_batch_size` or `max_total_tokens`. In fact you could mostly relax them if you
-    /// want maximum flexibility. However, for your users if they are asking for the full amount of
-    /// total tokens, they are likely to wait for a very long time to get a spot
-    /// in the batch (since they are going to be alone) so setting `max_batch_size`
-    /// and `max_total_tokens` can still be useful to prevent those long waiting times.
-    ///
    /// Overall this number should be the largest possible amount that fits the
    /// remaining memory (after the model is loaded). 
Since the actual memory overhead /// depends on other parameters like if you're using quantization, flash attention @@ -448,7 +444,7 @@ fn shard_manager( // We received a shutdown signal if *shutdown.lock().unwrap() { - p.terminate().unwrap(); + p.kill().unwrap(); let _ = p.wait_timeout(Duration::from_secs(90)); tracing::info!("Shard {rank} terminated"); return; diff --git a/proto/generate.proto b/proto/generate.proto index a0f5a75e..5e061941 100644 --- a/proto/generate.proto +++ b/proto/generate.proto @@ -11,6 +11,8 @@ service TextGenerationService { rpc ClearCache (ClearCacheRequest) returns (ClearCacheResponse); /// Remove requests from a cached batch rpc FilterBatch (FilterBatchRequest) returns (FilterBatchResponse); + /// Warmup the model and compute max cache size + rpc Warmup (WarmupRequest) returns (WarmupResponse); /// Prefill batch and decode first token rpc Prefill (PrefillRequest) returns (PrefillResponse); /// Decode token for a list of prefilled batches @@ -192,3 +194,13 @@ message DecodeResponse { /// Next batch (cached) optional CachedBatch batch = 2; } + +message WarmupRequest { + /// Batch to warmup on + Batch batch = 1; + /// Maximum number of tokens that the client will send + uint32 max_total_tokens = 2; +} + +/// Empty response +message WarmupResponse {} diff --git a/router/client/src/client.rs b/router/client/src/client.rs index 81f023ef..c5396cc4 100644 --- a/router/client/src/client.rs +++ b/router/client/src/client.rs @@ -3,6 +3,7 @@ use crate::pb::generate::v1::text_generation_service_client::TextGenerationServi use crate::pb::generate::v1::*; use crate::Result; use grpc_metadata::InjectTelemetryContext; +use std::cmp::min; use tonic::transport::{Channel, Uri}; use tracing::instrument; @@ -94,6 +95,63 @@ impl Client { Ok(filtered_batch.batch) } + /// Warmup on a max size batch + /// + /// Returns the maximum amount of tokens supported by the hardware + #[instrument(skip(self))] + pub async fn warmup( + &mut self, + max_input_length: u32, + max_prefill_tokens: u32, + max_total_tokens: u32, + ) -> Result<()> { + let mut n_tokens = 0; + let mut requests = Vec::new(); + + // Create requests + while n_tokens < max_prefill_tokens { + requests.push(Request { + id: 0, + // We truncate the input on the server side to be sure that it has the correct size + inputs: "test".to_string().repeat(max_input_length as usize), + truncate: min(max_input_length, max_prefill_tokens - n_tokens), + // Set sampling parameters to also take these ops into account in the max memory + parameters: Some(NextTokenChooserParameters { + temperature: 0.9, + top_k: 10, + top_p: 0.9, + typical_p: 0.9, + do_sample: false, + seed: 0, + repetition_penalty: 1.2, + watermark: true, + }), + stopping_parameters: Some(StoppingCriteriaParameters { + max_new_tokens: 2, + stop_sequences: vec![], + ignore_eos_token: false, + }), + prefill_logprobs: true, + }); + n_tokens += max_input_length; + } + + let batch = Batch { + id: 0, + size: requests.len() as u32, + requests, + max_tokens: 0, + }; + + let request = tonic::Request::new(WarmupRequest { + batch: Some(batch), + max_total_tokens, + }) + .inject_context(); + self.stub.warmup(request).await?.into_inner(); + Ok(()) + } + /// Generate one token for each request in the given batch /// /// Returns Generation for each request in batch diff --git a/router/client/src/sharded_client.rs b/router/client/src/sharded_client.rs index b81eed46..9dd173a0 100644 --- a/router/client/src/sharded_client.rs +++ b/router/client/src/sharded_client.rs @@ -87,6 +87,27 @@ impl 
ShardedClient { join_all(futures).await.pop().unwrap() } + /// Warmup on a max size batch + /// + /// Returns the maximum amount of tokens supported by the hardware + #[instrument(skip(self))] + pub async fn warmup( + &mut self, + max_input_length: u32, + max_prefill_tokens: u32, + max_total_tokens: u32, + ) -> Result<()> { + let futures: Vec<_> = self + .clients + .iter_mut() + .map(|client| { + Box::pin(client.warmup(max_input_length, max_prefill_tokens, max_total_tokens)) + }) + .collect(); + // all shards return the same message + join_all(futures).await.pop().unwrap() + } + /// Generate one token for each request in the given batch /// /// Returns Generation for each request in batch diff --git a/router/src/infer.rs b/router/src/infer.rs index 255932ef..8d93d2a1 100644 --- a/router/src/infer.rs +++ b/router/src/infer.rs @@ -242,6 +242,7 @@ impl Infer { /// Will be launched in a background Tokio task /// /// Batches requests and sends them to the inference server +#[allow(clippy::too_many_arguments)] async fn batching_task( mut client: ShardedClient, waiting_served_ratio: f32, @@ -288,7 +289,8 @@ async fn batching_task( Some((batch_size as f32 * waiting_served_ratio).floor() as usize) }; - let token_budget = max_batch_total_tokens - batch_max_tokens; + let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens); + tracing::info!("{token_budget} {batch_max_tokens}"); // Try to get a new batch if let Some((mut new_entries, new_batch, span)) = queue diff --git a/router/src/main.rs b/router/src/main.rs index 4c9a6740..474f4e06 100644 --- a/router/src/main.rs +++ b/router/src/main.rs @@ -34,7 +34,7 @@ struct Args { max_total_tokens: usize, #[clap(default_value = "1.2", long, env)] waiting_served_ratio: f32, - #[clap(default_value = "32000", long, env)] + #[clap(default_value = "4096", long, env)] max_batch_prefill_tokens: u32, #[clap(default_value = "32000", long, env)] max_batch_total_tokens: u32, @@ -180,16 +180,23 @@ fn main() -> Result<(), std::io::Error> { let mut sharded_client = ShardedClient::connect_uds(master_shard_uds_path) .await .expect("Could not connect to server"); - // Clear the cache; useful if the webserver rebooted - sharded_client - .clear_cache(None) - .await - .expect("Unable to clear cache"); + // Get info from the shard let shard_info = sharded_client .info() .await .expect("Unable to get shard info"); + + // Warmup model + tracing::info!("Warming up model"); + sharded_client + .warmup( + max_input_length as u32, + max_batch_prefill_tokens, + max_batch_total_tokens, + ) + .await + .expect("Unable to warmup model"); tracing::info!("Connected"); // Binds on localhost diff --git a/server/text_generation_server/cache.py b/server/text_generation_server/cache.py index 148927c1..fc4c0d3a 100644 --- a/server/text_generation_server/cache.py +++ b/server/text_generation_server/cache.py @@ -19,11 +19,12 @@ class Cache: def delete(self, batch_id: int): batch = self.pop(batch_id) if batch is not None: - batch.cleanup() + batch.free() del batch def clear(self): - for k in self.cache.keys(): + keys = list(self.cache.keys()) + for k in keys: self.delete(k) def __len__(self): diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py index ba0853f5..6d47c6eb 100644 --- a/server/text_generation_server/models/causal_lm.py +++ b/server/text_generation_server/models/causal_lm.py @@ -122,7 +122,7 @@ class CausalLMBatch(Batch): position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1) all_input_ids = 
tokenized_inputs["input_ids"].T.split(1, dim=1) - max_tokens = len(inputs) * max_input_length + max_decode_tokens + max_tokens = len(inputs) * (max_input_length + max_decode_tokens) return cls( batch_id=pb.id, diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index d57e78c3..02fccd01 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -1,3 +1,4 @@ +import math import itertools import torch import torch.distributed @@ -5,6 +6,7 @@ import torch.distributed import numpy as np from dataclasses import dataclass +from loguru import logger from opentelemetry import trace from transformers import PreTrainedTokenizerBase from typing import Optional, Tuple, List, Type, Union, Dict @@ -21,6 +23,7 @@ from text_generation_server.utils import StoppingCriteria, HeterogeneousNextToke tracer = trace.get_tracer(__name__) +BLOCK_SIZE = 16 # Will be set in warmup CACHE_MANAGER: Optional["CacheManager"] = None @@ -35,7 +38,7 @@ class CacheManager: dtype: torch.dtype, device: torch.device, ): - self.block_size = 16 + self.block_size = BLOCK_SIZE element_size = torch.tensor([], dtype=dtype).element_size() x = self.block_size // element_size @@ -60,26 +63,30 @@ class CacheManager: 0, num_blocks * self.block_size, dtype=torch.int32 ).view(num_blocks, self.block_size) - def allocate(self, n_tokens: int) -> Tuple[List[int], torch.Tensor]: - # Number of needed block to cover all tokens - needed_blocks = (n_tokens // self.block_size) + 1 - + def allocate(self, num_blocks: int) -> Tuple[torch.Tensor, torch.Tensor]: # Get free blocks indices by finding values in mask that are not set to 0 free_block_indices = self.free_block_mask.nonzero() - assert len(free_block_indices) >= needed_blocks, "Out of available cache blocks" + logger.info(f"Free blocks: {len(free_block_indices)}") + assert ( + len(free_block_indices) >= num_blocks + ), f"Out of available cache blocks: asked {num_blocks}, only {len(free_block_indices)} free blocks" # Allocate the required number of blocks by setting the mask to 0 - block_indices = free_block_indices[:needed_blocks] + block_indices = free_block_indices[:num_blocks] self.free_block_mask[block_indices] = 0 # Get slots for the allocated blocks - slots = self.slots[block_indices].flatten()[:n_tokens] + slots = self.slots[block_indices].flatten() - return block_indices.flatten().tolist(), slots + logger.info(f"allocate {num_blocks} blocks") - def free(self, block_indices: List[int]): - # Reset mask - self.free_block_mask[block_indices] = 1 + return block_indices.flatten(), slots + + def free(self, block_indices: Optional[List[int]]): + if block_indices is not None: + # Reset mask + logger.info(f"free {len(block_indices)} blocks") + self.free_block_mask[block_indices] = 1 @dataclass @@ -97,16 +104,25 @@ class FlashCausalLMBatch(Batch): start_seq_prefill: Optional[torch.Tensor] # tensor of length b holding ending offset of each sequence, only used in prefill end_seq_prefill: Optional[torch.Tensor] - # list of length b of list of length s_i // block_size - block_tables: List[List[int]] - # tensor of size [b, max_seqlen // block_size] holding the paged attention block tables for all sequences - block_tables_tensor: torch.Tensor + + # Paged Attention values + + # Set when creating the batch # CPU tensor of length b indicating the start of each sequence in slots start_slots: torch.Tensor - # tensor of length \sum_{i=0}^{b} max_s_i holding the 
paged attention slots for all sequences - slots: torch.Tensor # tensor of indices of the currently used slots, length = \sum_{i=0}^{b} s_i in prefill, length = b in decode slot_indices: torch.Tensor + # List of tuple of ints representing the number of blocks and slots needed by each sequence + needed_blocks_slots: Optional[List[Tuple[int, int]]] + + # Set in prefill by the CacheManager + # list of length b of list of length s_i // block_size + block_tables: Optional[List[List[int]]] + # tensor of size [b, max_seqlen // block_size] holding the paged attention block tables for all sequences + block_tables_tensor: Optional[torch.Tensor] + # tensor of length \sum_{i=0}^{b} max_s_i holding the paged attention slots for all sequences + slots: Optional[torch.Tensor] + max_seqlen: int # Prefill metadata tensors to efficiently compute logprobs @@ -128,16 +144,17 @@ class FlashCausalLMBatch(Batch): next_token_chooser: HeterogeneousNextTokenChooser stopping_criterias: List[StoppingCriteria] + # Number of blocks in this batch + blocks: int # Maximum number of blocks max_blocks: int def to_pb(self) -> generate_pb2.CachedBatch: - global CACHE_MANAGER return generate_pb2.CachedBatch( id=self.batch_id, request_ids=[r.id for r in self.requests], size=len(self), - max_tokens=len(self.slots), + max_tokens=self.blocks * BLOCK_SIZE, ) @classmethod @@ -148,8 +165,6 @@ class FlashCausalLMBatch(Batch): dtype: torch.dtype, device: torch.device, ) -> "FlashCausalLMBatch": - global CACHE_MANAGER - batch_inputs = [] max_truncation = 0 for r in pb.requests: @@ -163,9 +178,8 @@ class FlashCausalLMBatch(Batch): position_ids = [] start_seq_prefill = [] end_seq_prefill = [] - block_tables = [] + needed_blocks_slots = [] start_slots = [] - slots = [] slot_indices = [] input_lengths = [] @@ -188,6 +202,7 @@ class FlashCausalLMBatch(Batch): cumulative_max_length = 0 prefill_out_cumulative_length = 0 + blocks = 0 max_seqlen = 0 max_length = 0 max_blocks = 0 @@ -228,9 +243,9 @@ class FlashCausalLMBatch(Batch): # Paged attention # Remove one as the first token des not have a past total_tokens = input_length + max_new_tokens - 1 - request_blocks, request_slots = CACHE_MANAGER.allocate(total_tokens) - block_tables.append(request_blocks) - slots.extend(request_slots) + needed_blocks = math.ceil(total_tokens / BLOCK_SIZE) + blocks += needed_blocks + needed_blocks_slots.append((needed_blocks, total_tokens)) start_slots.append(cumulative_max_length) request_slot_indices = torch.arange( @@ -264,7 +279,7 @@ class FlashCausalLMBatch(Batch): cumulative_length += input_length cumulative_max_length += total_tokens max_seqlen = max(max_seqlen, input_length) - max_blocks = max(max_blocks, len(request_blocks)) + max_blocks = max(max_blocks, needed_blocks) max_length = max(max_length, input_length + max_new_tokens) next_token_chooser = HeterogeneousNextTokenChooser.from_pb( @@ -272,15 +287,6 @@ class FlashCausalLMBatch(Batch): ) start_slots = torch.tensor(start_slots, dtype=torch.int64) - # Padded block tables - block_tables_tensor = torch.zeros( - (len(pb.requests), max_blocks), dtype=torch.int32 - ) - for i, request_blocks in enumerate(block_tables): - block_tables_tensor[i, : len(request_blocks)] = torch.tensor(request_blocks) - - block_tables_tensor = block_tables_tensor.to(device) - # Padded all_input_ids_tensor all_input_ids_tensor = np.zeros( (len(all_input_ids), max_length), dtype=np.int64 @@ -312,7 +318,6 @@ class FlashCausalLMBatch(Batch): position_ids = position_ids.to(device) slot_indices = slot_indices.to(device) input_ids = 
torch.tensor(input_ids, dtype=torch.int64, device=device) - slots = torch.tensor(slots, dtype=torch.int32, device=device) input_lengths_tensor = torch.tensor( input_lengths, dtype=torch.int32, device=device ) @@ -339,11 +344,12 @@ class FlashCausalLMBatch(Batch): position_ids=position_ids, start_seq_prefill=start_seq_prefill, end_seq_prefill=end_seq_prefill, - block_tables=block_tables, - block_tables_tensor=block_tables_tensor, start_slots=start_slots, - slots=slots, slot_indices=slot_indices, + needed_blocks_slots=needed_blocks_slots, + block_tables=None, + block_tables_tensor=None, + slots=None, max_seqlen=max_seqlen, prefill_head_indices=prefill_head_indices, prefill_next_token_indices=prefill_next_token_indices, @@ -356,12 +362,12 @@ class FlashCausalLMBatch(Batch): all_input_ids_tensor=all_input_ids_tensor, next_token_chooser=next_token_chooser, stopping_criterias=stopping_criterias, + blocks=blocks, max_blocks=max_blocks, ) @tracer.start_as_current_span("filter") def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch": - global CACHE_MANAGER if len(request_ids) == 0: raise ValueError("Batch must have at least one request") # We assume that if len(requests) == len(self) then the requests are the same @@ -396,6 +402,7 @@ class FlashCausalLMBatch(Batch): stopping_criterias = [] + blocks = 0 max_blocks = 0 # Cumulative length cumulative_max_length = 0 @@ -425,6 +432,7 @@ class FlashCausalLMBatch(Batch): ) request_block_table = self.block_tables[idx] + blocks += len(request_block_table) block_tables.append(request_block_table) start_slots.append(cumulative_max_length) @@ -443,6 +451,7 @@ class FlashCausalLMBatch(Batch): max_blocks = max(max_blocks, len(request_block_table)) + global CACHE_MANAGER # Iterate on all requests for i, r in enumerate(self.requests): # Filter requests that are not part of the new batch @@ -472,11 +481,12 @@ class FlashCausalLMBatch(Batch): position_ids=position_ids, start_seq_prefill=None, end_seq_prefill=None, + start_slots=start_slots, + slot_indices=slot_indices, + needed_blocks_slots=None, block_tables=block_tables, block_tables_tensor=block_tables_tensor, - start_slots=start_slots, slots=slots, - slot_indices=slot_indices, max_seqlen=max_seqlen, prefill_head_indices=None, prefill_next_token_indices=None, @@ -489,17 +499,18 @@ class FlashCausalLMBatch(Batch): all_input_ids_tensor=all_input_ids_tensor, next_token_chooser=next_token_chooser, stopping_criterias=stopping_criterias, + blocks=blocks, max_blocks=max_blocks, ) @classmethod @tracer.start_as_current_span("concatenate") def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch": - global CACHE_MANAGER # Batch attributes requests = [] requests_idx_mapping = {} + blocks = 0 total_batch_size = 0 total_slots = 0 max_blocks = 0 @@ -508,6 +519,7 @@ class FlashCausalLMBatch(Batch): for b in batches: total_batch_size += len(b) total_slots += len(b.slots) + blocks += b.blocks max_blocks = max(max_blocks, b.max_blocks) max_seqlen = max(max_seqlen, b.max_seqlen) max_length = max( @@ -613,11 +625,12 @@ class FlashCausalLMBatch(Batch): position_ids=position_ids, start_seq_prefill=None, end_seq_prefill=None, + start_slots=start_slots, + slot_indices=slot_indices, + needed_blocks_slots=None, block_tables=block_tables, block_tables_tensor=block_tables_tensor, - start_slots=start_slots, slots=slots, - slot_indices=slot_indices, max_seqlen=max_seqlen, prefill_head_indices=None, prefill_next_token_indices=None, @@ -630,13 +643,15 @@ class FlashCausalLMBatch(Batch): 
all_input_ids_tensor=all_input_ids_tensor, next_token_chooser=next_token_chooser, stopping_criterias=stopping_criterias, + blocks=blocks, max_blocks=max_blocks, ) - def cleanup(self): - global CACHE_MANAGER - # Free blocks - CACHE_MANAGER.free(list(itertools.chain.from_iterable(self.block_tables))) + def free(self): + if self.block_tables is not None: + global CACHE_MANAGER + # Free blocks + CACHE_MANAGER.free(list(itertools.chain.from_iterable(self.block_tables))) def __len__(self): return len(self.requests) @@ -648,22 +663,17 @@ class FlashCausalLM(Model): model: torch.nn.Module, tokenizer: PreTrainedTokenizerBase, num_layers: int, - num_heads: int, + num_kv_heads: int, head_size: int, dtype: torch.dtype, device: torch.device, rank: int = 0, world_size: int = 1, ): - self.num_heads = num_heads + self.num_layers = num_layers + self.num_kv_heads = num_kv_heads self.head_size = head_size - global CACHE_MANAGER - torch.cuda.set_per_process_memory_fraction(1.0) - CACHE_MANAGER = CacheManager( - 1000, num_layers, num_heads, head_size, dtype, device - ) - super(FlashCausalLM, self).__init__( model=model, tokenizer=tokenizer, @@ -678,6 +688,30 @@ class FlashCausalLM(Model): def batch_type(self) -> Type[FlashCausalLMBatch]: return FlashCausalLMBatch + def warmup(self, batch: FlashCausalLMBatch, max_total_tokens: int): + global CACHE_MANAGER + + torch.cuda.empty_cache() + try: + CACHE_MANAGER = CacheManager( + # Adds some wiggle room + math.ceil(max_total_tokens / BLOCK_SIZE) + 10, + self.num_layers, + self.num_kv_heads, + self.head_size, + self.dtype, + self.device, + ) + _, batch = self.generate_token(batch) + except Exception as e: + logger.error( + f"Not enough memory to handle {max_total_tokens} total tokens with {len(batch.input_ids)} " + f"prefill tokens. 
" + f"You need to decrease `--max-batch-total-tokens` and `--max-batch-prefill-tokens`" + ) + raise e + batch.free() + def decode(self, generated_ids: Union[torch.Tensor, List[int]]) -> str: return self.tokenizer.decode( generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False @@ -718,6 +752,35 @@ class FlashCausalLM(Model): prefill = batch.start_seq_prefill is not None prefill_logprobs = batch.prefill_next_token_indices is not None + if batch.needed_blocks_slots: + # Padded block tables + block_tables_tensor = torch.zeros( + (len(batch), batch.max_blocks), dtype=torch.int32 + ) + + # Allocate paged attention blocks + slots = [] + block_tables = [] + try: + for i, (needed_blocks, needed_slots) in enumerate( + batch.needed_blocks_slots + ): + allocated_blocks, allocated_slots = CACHE_MANAGER.allocate( + needed_blocks + ) + slots.append(allocated_slots[:needed_slots]) + block_tables.append(allocated_blocks.tolist()) + block_tables_tensor[i, :needed_blocks] = allocated_blocks + except Exception as e: + for blocks in block_tables: + CACHE_MANAGER.free(blocks) + raise e + + batch.needed_blocks_slots = None + batch.block_tables = block_tables + batch.block_tables_tensor = block_tables_tensor.to(self.device) + batch.slots = torch.concat(slots).to(self.device) + out = self.forward( batch.input_ids, batch.position_ids, @@ -931,7 +994,7 @@ class FlashCausalLM(Model): batch.all_input_ids[i] = all_input_ids if stopped: - batch.cleanup() + batch.free() # No need to return a batch if we know that all requests stopped return generations, None diff --git a/server/text_generation_server/models/flash_llama.py b/server/text_generation_server/models/flash_llama.py index 383b8f43..2c59f01e 100644 --- a/server/text_generation_server/models/flash_llama.py +++ b/server/text_generation_server/models/flash_llama.py @@ -68,7 +68,7 @@ class FlashLlama(FlashCausalLM): model=model, tokenizer=tokenizer, num_layers=len(model.model.layers), - num_heads=model.model.num_heads, + num_kv_heads=model.model.num_heads, head_size=model.model.head_size, dtype=dtype, device=device, diff --git a/server/text_generation_server/models/model.py b/server/text_generation_server/models/model.py index 6b8472a5..f8460fc2 100644 --- a/server/text_generation_server/models/model.py +++ b/server/text_generation_server/models/model.py @@ -22,6 +22,9 @@ class Model(ABC): rank: int = 0, world_size: int = 1, ): + if torch.cuda.is_available(): + torch.cuda.set_per_process_memory_fraction(1.0) + self.model = model.eval() self.tokenizer = tokenizer self.all_special_ids = set(tokenizer.all_special_ids) @@ -55,6 +58,9 @@ class Model(ABC): def generate_token(self, batch: B) -> Tuple[List[GeneratedText], Optional[B]]: raise NotImplementedError + def warmup(self, batch: B, max_total_tokens: int): + self.generate_token(batch) + def decode_token( self, all_input_ids: List[int], diff --git a/server/text_generation_server/models/seq2seq_lm.py b/server/text_generation_server/models/seq2seq_lm.py index 3ad5698c..999b6637 100644 --- a/server/text_generation_server/models/seq2seq_lm.py +++ b/server/text_generation_server/models/seq2seq_lm.py @@ -127,7 +127,7 @@ class Seq2SeqLMBatch(Batch): read_offsets.append(1) all_decoder_input_ids = decoder_input_ids.view(-1).split(1) - max_tokens = len(inputs) * max_input_length + max_decode_tokens + max_tokens = len(inputs) * (max_input_length + max_decode_tokens) return cls( batch_id=pb.id, diff --git a/server/text_generation_server/models/types.py b/server/text_generation_server/models/types.py index 
bd92022e..c35e15d3 100644 --- a/server/text_generation_server/models/types.py +++ b/server/text_generation_server/models/types.py @@ -35,7 +35,7 @@ class Batch(ABC): def concatenate(cls, batches: List["Batch"]) -> "Batch": raise NotImplementedError - def cleanup(self): + def free(self): pass @abstractmethod diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py index e1bd8412..378ac841 100644 --- a/server/text_generation_server/server.py +++ b/server/text_generation_server/server.py @@ -53,12 +53,24 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer): return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb()) + async def Warmup(self, request, context): + batch = self.model.batch_type.from_pb( + request.batch, self.model.tokenizer, self.model.dtype, self.model.device + ) + self.model.warmup(batch, request.max_total_tokens) + return generate_pb2.WarmupResponse() + async def Prefill(self, request, context): batch = self.model.batch_type.from_pb( request.batch, self.model.tokenizer, self.model.dtype, self.model.device ) - generations, next_batch = self.model.generate_token(batch) + try: + generations, next_batch = self.model.generate_token(batch) + except Exception as e: + batch.free() + raise e + self.cache.set(next_batch) return generate_pb2.PrefillResponse( @@ -81,11 +93,20 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer): raise ValueError("All batches are empty") if len(batches) > 1: - batch = self.model.batch_type.concatenate(batches) + try: + batch = self.model.batch_type.concatenate(batches) + except Exception as e: + [batch.free() for batch in batches] + raise e else: batch = batches[0] - generations, next_batch = self.model.generate_token(batch) + try: + generations, next_batch = self.model.generate_token(batch) + except Exception as e: + batch.free() + raise e + self.cache.set(next_batch) return generate_pb2.DecodeResponse(
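
Note on the block arithmetic this patch introduces: the link between `max_batch_prefill_tokens`, `max_batch_total_tokens`, the per-request `total_tokens` and the cache size chosen during warmup is easy to lose across the hunks above. Below is a minimal Python sketch of that accounting; `BLOCK_SIZE`, the `input_length + max_new_tokens - 1` formula, the `blocks * BLOCK_SIZE` reporting and the `+ 10` wiggle room come from the patch, while the helper names and the example request sizes are purely illustrative.

import math

BLOCK_SIZE = 16  # kv-cache block size used by the flash models in this patch

def blocks_needed(input_length: int, max_new_tokens: int) -> int:
    # Mirrors the patch: "Remove one as the first token does not have a past".
    total_tokens = input_length + max_new_tokens - 1
    return math.ceil(total_tokens / BLOCK_SIZE)

def batch_max_tokens(requests) -> int:
    # CachedBatch.max_tokens is now reported as blocks * BLOCK_SIZE,
    # i.e. the memory actually reserved rather than the exact token count.
    return sum(blocks_needed(inp, new) for inp, new in requests) * BLOCK_SIZE

def warmup_cache_blocks(max_total_tokens: int) -> int:
    # Number of blocks given to the CacheManager created during warmup,
    # with some wiggle room on top.
    return math.ceil(max_total_tokens / BLOCK_SIZE) + 10

# Example with made-up request sizes (input_length, max_new_tokens). The router
# subtracts this value from --max-batch-total-tokens (saturating at zero) to get
# the token budget used when pulling new requests from the queue.
print(batch_max_tokens([(20, 100), (512, 32)]))  # 672
print(warmup_cache_blocks(32000))                # 2010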
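
In the same spirit, here is a simplified sketch of the free-block bookkeeping behind the new `CacheManager.allocate`/`free` path: blocks of `BLOCK_SIZE` slots are handed out from a mask and returned on `free`. It omits the kv tensors and per-layer storage; `SimpleBlockAllocator` and the example sizes are illustrative, not part of the patch.

import math
import torch

BLOCK_SIZE = 16

class SimpleBlockAllocator:
    def __init__(self, num_blocks: int):
        # 1 = free, 0 = allocated
        self.free_block_mask = torch.ones(num_blocks, dtype=torch.int32)
        # Slot ids laid out contiguously, one row of BLOCK_SIZE slots per block
        self.slots = torch.arange(num_blocks * BLOCK_SIZE, dtype=torch.int32).view(
            num_blocks, BLOCK_SIZE
        )

    def allocate(self, num_blocks: int):
        free = self.free_block_mask.nonzero().flatten()
        assert len(free) >= num_blocks, (
            f"Out of available cache blocks: asked {num_blocks}, only {len(free)} free"
        )
        blocks = free[:num_blocks]
        self.free_block_mask[blocks] = 0
        return blocks, self.slots[blocks].flatten()

    def free(self, blocks):
        # Mark the blocks as free again
        self.free_block_mask[blocks] = 1

# A request needing 150 kv-cache slots takes ceil(150 / BLOCK_SIZE) blocks out of
# a cache sized from max_total_tokens, and gives them back when the batch is freed.
allocator = SimpleBlockAllocator(num_blocks=math.ceil(4096 / BLOCK_SIZE) + 10)
blocks, slots = allocator.allocate(math.ceil(150 / BLOCK_SIZE))
allocator.free(blocks)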