diff --git a/launcher/src/main.rs b/launcher/src/main.rs index 2e2bc7a5..0684de0f 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -115,12 +115,6 @@ struct Args { #[clap(default_value = "1512", long, env)] max_total_tokens: usize, - /// The maximum allowed batch size during dynamic batching. - /// Using `max_batch_total_tokens` should be favored in general - /// as it's a finer way to control RAM usage. - #[clap(long, env)] - max_batch_size: Option, - /// This represents the ratio of waiting queries vs running queries where /// you want to start considering pausing the running queries to include the waiting /// ones into the same batch. @@ -134,6 +128,9 @@ struct Args { #[clap(default_value = "1.2", long, env)] waiting_served_ratio: f32, + #[clap(default_value = "32000", long, env)] + max_batch_prefill_tokens: u32, + /// **IMPORTANT** This is one critical control to allow maximum usage /// of the available hardware. /// @@ -181,7 +178,6 @@ struct Args { #[clap(default_value = "20", long, env)] max_waiting_tokens: usize, #[clap(default_value = "3000", long, short, env)] - /// The port to listen on. port: u16, @@ -329,6 +325,12 @@ fn shard_manager( // Copy current process env let mut env: Vec<(OsString, OsString)> = env::vars_os().collect(); + // Use cuda allocator. It leads to less memory fragmentation + env.push(( + "PYTORCH_CUDA_ALLOC_CONF".into(), + "backend:cudaMallocAsync".into(), + )); + // Torch Distributed Env vars env.push(("RANK".into(), rank.to_string().into())); env.push(("WORLD_SIZE".into(), world_size.to_string().into())); @@ -822,6 +824,10 @@ fn spawn_webserver( args.max_input_length.to_string(), "--max-total-tokens".to_string(), args.max_total_tokens.to_string(), + "--max-batch-prefill-tokens".to_string(), + args.max_batch_prefill_tokens.to_string(), + "--max-batch-total-tokens".to_string(), + args.max_batch_total_tokens.to_string(), "--waiting-served-ratio".to_string(), args.waiting_served_ratio.to_string(), "--max-waiting-tokens".to_string(), @@ -834,15 +840,6 @@ fn spawn_webserver( args.model_id, ]; - // Deprecate max_batch_size - if let Some(max_batch_size) = args.max_batch_size { - argv.push("--max-batch-size".to_string()); - argv.push(max_batch_size.to_string()) - } else { - argv.push("--max-batch-total-tokens".to_string()); - argv.push(args.max_batch_total_tokens.to_string()) - } - // Model optional revision if let Some(ref revision) = args.revision { argv.push("--revision".to_string()); diff --git a/router/src/infer.rs b/router/src/infer.rs index f738f986..255932ef 100644 --- a/router/src/infer.rs +++ b/router/src/infer.rs @@ -45,6 +45,7 @@ impl Infer { client: ShardedClient, validation: Validation, waiting_served_ratio: f32, + max_batch_prefill_tokens: u32, max_batch_total_tokens: u32, max_waiting_tokens: usize, max_concurrent_requests: usize, @@ -61,6 +62,7 @@ impl Infer { tokio::spawn(batching_task( client, waiting_served_ratio, + max_batch_prefill_tokens, max_batch_total_tokens, max_waiting_tokens, queue.clone(), @@ -243,6 +245,7 @@ impl Infer { async fn batching_task( mut client: ShardedClient, waiting_served_ratio: f32, + max_batch_prefill_tokens: u32, max_batch_total_tokens: u32, max_waiting_tokens: usize, queue: Queue, @@ -257,8 +260,9 @@ async fn batching_task( // Get the next batch from the queue // This batch might be smaller than the maximum batch size if there are not enough requests // waiting in the queue - while let Some((mut entries, batch, span)) = - queue.next_batch(None, max_batch_total_tokens).await + while let Some((mut entries, 
batch, span)) = - queue.next_batch(None, max_batch_total_tokens).await + while let Some((mut entries, batch, span)) = queue + .next_batch(None, max_batch_prefill_tokens, max_batch_total_tokens) + .await { let mut cached_batch = prefill(&mut client, batch, &mut entries, &generation_health) .instrument(span) @@ -287,8 +291,9 @@ async fn batching_task( let token_budget = max_batch_total_tokens - batch_max_tokens; // Try to get a new batch - if let Some((mut new_entries, new_batch, span)) = - queue.next_batch(min_size, token_budget).await + if let Some((mut new_entries, new_batch, span)) = queue + .next_batch(min_size, max_batch_prefill_tokens, token_budget) + .await { // Tracking metrics if min_size.is_some() { diff --git a/router/src/main.rs b/router/src/main.rs index 7bbb6477..4c9a6740 100644 --- a/router/src/main.rs +++ b/router/src/main.rs @@ -32,11 +32,11 @@ struct Args { max_input_length: usize, #[clap(default_value = "1512", long, env)] max_total_tokens: usize, - #[clap(long, env)] - max_batch_size: Option<usize>, #[clap(default_value = "1.2", long, env)] waiting_served_ratio: f32, #[clap(default_value = "32000", long, env)] + max_batch_prefill_tokens: u32, + #[clap(default_value = "32000", long, env)] max_batch_total_tokens: u32, #[clap(default_value = "20", long, env)] max_waiting_tokens: usize, @@ -78,9 +78,9 @@ fn main() -> Result<(), std::io::Error> { max_stop_sequences, max_input_length, max_total_tokens, - max_batch_size, waiting_served_ratio, - mut max_batch_total_tokens, + max_batch_prefill_tokens, + max_batch_total_tokens, max_waiting_tokens, port, master_shard_uds_path, @@ -141,12 +141,6 @@ fn main() -> Result<(), std::io::Error> { .block_on(async { init_logging(otlp_endpoint, json_output); - if let Some(max_batch_size) = max_batch_size { - tracing::warn!("`max-batch-size` is deprecated. Use `max-batch-total-tokens` instead"); - max_batch_total_tokens = (max_batch_size * max_total_tokens) as u32; - tracing::warn!("Overriding `max-batch-total-tokens` value with `max-batch-size` * `max-total-tokens` = {max_batch_total_tokens}"); - } - if tokenizer.is_none() { tracing::warn!( "Could not find a fast tokenizer implementation for {tokenizer_name}" @@ -161,10 +155,16 @@ fn main() -> Result<(), std::io::Error> { sha: None, pipeline_tag: None, }, - false => get_model_info(&tokenizer_name, &revision, authorization_token).await.unwrap_or_else(|| { - tracing::warn!("Could not retrieve model info from the Hugging Face hub."); - HubModelInfo { model_id: tokenizer_name.to_string(), sha: None, pipeline_tag: None } - }), + false => get_model_info(&tokenizer_name, &revision, authorization_token) + .await + .unwrap_or_else(|| { + tracing::warn!("Could not retrieve model info from the Hugging Face hub."); + HubModelInfo { + model_id: tokenizer_name.to_string(), + sha: None, + pipeline_tag: None, + } + }), }; // if pipeline-tag == text-generation we default to return_full_text = true @@ -206,6 +206,7 @@ fn main() -> Result<(), std::io::Error> { max_input_length, max_total_tokens, waiting_served_ratio, + max_batch_prefill_tokens, max_batch_total_tokens, max_waiting_tokens, sharded_client, @@ -219,7 +220,7 @@ fn main() -> Result<(), std::io::Error> { ngrok_username, ngrok_password, ) - .await; + .await; Ok(()) }) } diff --git a/router/src/queue.rs b/router/src/queue.rs index 6d1d4d12..a3a607e7 100644 --- a/router/src/queue.rs +++ b/router/src/queue.rs @@ -58,6 +58,7 @@ impl Queue { pub(crate) async fn next_batch( &self, min_size: Option<usize>, + prefill_token_budget: u32, token_budget: u32, ) -> Option<NextBatch> { // Create response channel @@ -67,6 +68,7 @@ impl Queue { self.queue_sender .send(QueueCommand::NextBatch {
min_size, + prefill_token_budget, token_budget, response_sender, span: Span::current(), @@ -90,11 +92,12 @@ async fn queue_task(requires_padding: bool, receiver: flume::Receiver<QueueCommand>) { QueueCommand::NextBatch { min_size, + prefill_token_budget, token_budget, response_sender, span, } => span.in_scope(|| { - let next_batch = state.next_batch(min_size, token_budget); + let next_batch = state.next_batch(min_size, prefill_token_budget, token_budget); response_sender.send(next_batch).unwrap(); metrics::gauge!("tgi_queue_size", state.entries.len() as f64); }), @@ -140,7 +143,12 @@ impl State { } // Get the next batch - fn next_batch(&mut self, min_size: Option<usize>, token_budget: u32) -> Option<NextBatch> { + fn next_batch( + &mut self, + min_size: Option<usize>, + prefill_token_budget: u32, + token_budget: u32, + ) -> Option<NextBatch> { if self.entries.is_empty() { return None; } @@ -184,7 +192,9 @@ impl State { decode_tokens += entry.request.stopping_parameters.max_new_tokens; - if (prefill_tokens + decode_tokens) > token_budget { + if prefill_tokens > prefill_token_budget + || (prefill_tokens + decode_tokens) > token_budget + { // Entry is over budget // Add it back to the front self.entries.push_front((id, entry)); @@ -259,6 +269,7 @@ enum QueueCommand { Append(Box<Entry>, Span), NextBatch { min_size: Option<usize>, + prefill_token_budget: u32, token_budget: u32, response_sender: oneshot::Sender<Option<NextBatch>>, span: Span, @@ -294,7 +305,7 @@ mod tests { watermark: false, }, stopping_parameters: StoppingCriteriaParameters { - ignore_eos_token: false, + ignore_eos_token: true, max_new_tokens: 1, stop_sequences: vec![], }, diff --git a/router/src/server.rs b/router/src/server.rs index b8c67b2c..04d1269b 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -152,7 +152,7 @@ async fn generate( let start_time = Instant::now(); metrics::increment_counter!("tgi_request_count"); - tracing::debug!("Input: {}", req.0.inputs); + // tracing::debug!("Input: {}", req.0.inputs); let compute_characters = req.0.inputs.chars().count(); let mut add_prompt = None; @@ -286,7 +286,7 @@ async fn generate( } tracing::debug!("Output: {}", output_text); - tracing::info!("Success"); + // tracing::info!("Success"); let response = GenerateResponse { generated_text: output_text, @@ -513,6 +513,7 @@ pub async fn run( max_input_length: usize, max_total_tokens: usize, waiting_served_ratio: f32, + max_batch_prefill_tokens: u32, max_batch_total_tokens: u32, max_waiting_tokens: usize, client: ShardedClient, @@ -581,6 +582,7 @@ pub async fn run( client, validation, waiting_served_ratio, + max_batch_prefill_tokens, max_batch_total_tokens, max_waiting_tokens, max_concurrent_requests, diff --git a/server/text_generation_server/cache.py b/server/text_generation_server/cache.py index 5556529c..148927c1 100644 --- a/server/text_generation_server/cache.py +++ b/server/text_generation_server/cache.py @@ -19,10 +19,12 @@ class Cache: def delete(self, batch_id: int): batch = self.pop(batch_id) if batch is not None: + batch.cleanup() del batch def clear(self): - self.cache.clear() + for k in self.cache.keys(): + self.delete(k) def __len__(self): return len(self.cache.keys()) diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index 993e1e2a..bb7fcbef 100644 --- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -23,7 +23,9 @@ import torch.distributed from torch import nn from transformers.activations import ACT2FN -from typing import Optional +from typing import
Optional, List, Tuple +from vllm import attention_ops +from vllm import cache_ops # Flash attention imports import flash_attn_cuda @@ -106,7 +108,7 @@ class FlashLlamaAttention(torch.nn.Module): prefix=f"{prefix}.rotary_emb", weights=weights ) - self.softmax_scale = self.head_size ** (-0.5) + self.softmax_scale = self.head_size**-0.5 self.num_heads = self.num_heads // weights.process_group.size() self.query_key_value = TensorParallelColumnLinear.load_multi( @@ -128,14 +130,13 @@ class FlashLlamaAttention(torch.nn.Module): hidden_states, cos, sin, - start_seq, - end_seq, - start_seq_q, - end_seq_q, + start_seq_prefill, + end_seq_prefill, + kv_cache, + block_tables, + slots, + input_lengths, max_s, - layer_past, - past_present_indices, - prefill, ): qkv = self.query_key_value(hidden_states) qkv = qkv.view(-1, 3, self.num_heads, self.head_size) @@ -144,23 +145,25 @@ class FlashLlamaAttention(torch.nn.Module): self.rotary_emb(qkv[:, 0], cos, sin) self.rotary_emb(qkv[:, 1], cos, sin) - # Prefill - if prefill: - # Copy to layer past - layer_past[...] = qkv[:, 1:] + cache_ops.reshape_and_cache( + qkv[:, 1], qkv[:, 2], kv_cache[0], kv_cache[1], slots + ) - # output - attn_output = torch.empty_like(qkv[:, 0]) + # output tensor + attn_output = torch.empty_like(qkv[:, 0]) + + # Prefill + if start_seq_prefill is not None: # flash attention flash_attn_cuda.fwd( qkv[:, 0], qkv[:, 1], qkv[:, 2], attn_output, - start_seq, - end_seq, - start_seq, - end_seq, + start_seq_prefill, + end_seq_prefill, + start_seq_prefill, + end_seq_prefill, max_s, max_s, 0.0, @@ -173,31 +176,18 @@ class FlashLlamaAttention(torch.nn.Module): ) # Decode else: - query = qkv[:, 0] - # Add present to the layer_past tensor at the correct indices - layer_past[past_present_indices] = qkv[:, 1:] - - # output - attn_output = torch.empty_like(query) - # flash attention - flash_attn_cuda.fwd( - query, - layer_past[:, 0], - layer_past[:, 1], + # kv_cache[1] => [num_blocks, num_heads, head_size, block_size] + block_size = kv_cache[1].shape[3] + attention_ops.single_query_cached_kv_attention( attn_output, - start_seq_q, - end_seq_q, - start_seq, - end_seq, - 1, - max_s, - 0.0, + qkv[:, 0], + kv_cache[0], + kv_cache[1], self.softmax_scale, - False, - False, - False, - 0, - None, + block_tables, + input_lengths, + block_size, + max_s, ) return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size)) @@ -265,14 +255,13 @@ class FlashLlamaLayer(nn.Module): residual, cos, sin, - start_seq, - end_seq, - start_seq_q, - end_seq_q, + start_seq_prefill, + end_seq_prefill, + kv_cache, + block_tables, + slots, + input_lengths, max_s, - layer_past, - past_present_indices, - prefill, ): normed_hidden_states, res = self.input_layernorm(hidden_states, residual) @@ -281,14 +270,13 @@ class FlashLlamaLayer(nn.Module): normed_hidden_states, cos, sin, - start_seq, - end_seq, - start_seq_q, - end_seq_q, + start_seq_prefill, + end_seq_prefill, + kv_cache, + block_tables, + slots, + input_lengths, max_s, - layer_past, - past_present_indices, - prefill, ) # faster post attention rms norm @@ -333,40 +321,18 @@ class FlashLlamaModel(torch.nn.Module): def forward( self, - input_ids, - position_ids, - start_seq, - end_seq, - start_seq_q, - end_seq_q, - max_s, - past_present_indices, - past_key_values=None, - pre_allocate_past_size: Optional[int] = None, - ): + input_ids: torch.Tensor, + position_ids: torch.Tensor, + start_seq_prefill: Optional[torch.Tensor], + end_seq_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + 
block_tables: torch.Tensor, + slots: torch.Tensor, + input_lengths: torch.Tensor, + max_s: int, + ) -> torch.Tensor: hidden_states = self.embed_tokens(input_ids) - # Prefill - if past_key_values is None: - assert pre_allocate_past_size is not None - - prefill = True - - # Create past tensor - # We create a tensor of the same size as input_ids as we don't want to slice at every layer - past_key_values = hidden_states.new_empty( - ( - len(input_ids), - len(self.layers), - 2, - self.num_heads, - self.head_size, - ) - ) - # Decode - else: - prefill = False - # Get rotary cos and sin for this forward # Avoid to index in each layer cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin( @@ -380,34 +346,18 @@ class FlashLlamaModel(torch.nn.Module): residual, cos, sin, - start_seq, - end_seq, - start_seq_q, - end_seq_q, + start_seq_prefill, + end_seq_prefill, + kv_cache[i], + block_tables, + slots, + input_lengths, max_s, - past_key_values[:, i], - past_present_indices, - prefill, ) - if prefill: - present = past_key_values - # Create padded past tensor - past_key_values = hidden_states.new_empty( - ( - pre_allocate_past_size, - len(self.layers), - 2, - self.num_heads, - self.head_size, - ) - ) - # We slice only once instead of at every layer - past_key_values[past_present_indices] = present - hidden_states, _ = self.norm(hidden_states, residual) - return hidden_states, past_key_values + return hidden_states class FlashLlamaForCausalLM(torch.nn.Module): @@ -423,31 +373,29 @@ class FlashLlamaForCausalLM(torch.nn.Module): def forward( self, - input_ids, - position_ids, - start_seq, - end_seq, - start_seq_q, - end_seq_q, - max_s, - past_present_indices, - past_key_values: Optional[torch.Tensor] = None, - pre_allocate_past_size: Optional[int] = None, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + start_seq_prefill: Optional[torch.Tensor], + end_seq_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + block_tables: torch.Tensor, + slots: torch.Tensor, + input_lengths: torch.Tensor, + max_s: int, lm_head_indices: Optional[torch.Tensor] = None, - ): - hidden_states, present = self.model( + ) -> torch.Tensor: + hidden_states = self.model( input_ids, position_ids, - start_seq, - end_seq, - start_seq_q, - end_seq_q, + start_seq_prefill, + end_seq_prefill, + kv_cache, + block_tables, + slots, + input_lengths, max_s, - past_present_indices, - past_key_values, - pre_allocate_past_size, ) if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] logits = self.lm_head(hidden_states) - return logits, present + return logits diff --git a/server/text_generation_server/models/custom_modeling/t5_modeling.py b/server/text_generation_server/models/custom_modeling/t5_modeling.py index 12679e9d..19deca86 100644 --- a/server/text_generation_server/models/custom_modeling/t5_modeling.py +++ b/server/text_generation_server/models/custom_modeling/t5_modeling.py @@ -1004,7 +1004,9 @@ class T5ForConditionalGeneration(T5PreTrainedModel): try: self.shared = TensorParallelEmbedding(prefix="shared", weights=weights) except RuntimeError: - self.shared = TensorParallelEmbedding(prefix="encoder.embed_tokens", weights=weights) + self.shared = TensorParallelEmbedding( + prefix="encoder.embed_tokens", weights=weights + ) encoder_config = copy.deepcopy(config) encoder_config.is_decoder = False diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index ecea998e..d57e78c3 100644 --- 
a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -1,3 +1,4 @@ +import itertools import torch import torch.distributed @@ -5,7 +6,7 @@ import numpy as np from dataclasses import dataclass from opentelemetry import trace -from transformers import AutoTokenizer, PreTrainedTokenizerBase, PreTrainedModel +from transformers import PreTrainedTokenizerBase from typing import Optional, Tuple, List, Type, Union, Dict from text_generation_server.models import Model @@ -20,6 +21,66 @@ from text_generation_server.utils import StoppingCriteria, HeterogeneousNextToke tracer = trace.get_tracer(__name__) +# Will be set in warmup +CACHE_MANAGER: Optional["CacheManager"] = None + + +class CacheManager: + def __init__( + self, + num_blocks: int, + num_layers: int, + num_heads: int, + head_size: int, + dtype: torch.dtype, + device: torch.device, + ): + self.block_size = 16 + + element_size = torch.tensor([], dtype=dtype).element_size() + x = self.block_size // element_size + + self.kv_cache = [ + ( + torch.empty( + (num_blocks, num_heads, head_size // x, self.block_size, x), + dtype=dtype, + device=device, + ), + torch.empty( + (num_blocks, num_heads, head_size, self.block_size), + dtype=dtype, + device=device, + ), + ) + for _ in range(num_layers) + ] + self.free_block_mask = torch.ones(num_blocks, dtype=torch.int32, device="cpu") + self.slots = torch.arange( + 0, num_blocks * self.block_size, dtype=torch.int32 + ).view(num_blocks, self.block_size) + + def allocate(self, n_tokens: int) -> Tuple[List[int], torch.Tensor]: + # Number of needed block to cover all tokens + needed_blocks = (n_tokens // self.block_size) + 1 + + # Get free blocks indices by finding values in mask that are not set to 0 + free_block_indices = self.free_block_mask.nonzero() + assert len(free_block_indices) >= needed_blocks, "Out of available cache blocks" + + # Allocate the required number of blocks by setting the mask to 0 + block_indices = free_block_indices[:needed_blocks] + self.free_block_mask[block_indices] = 0 + + # Get slots for the allocated blocks + slots = self.slots[block_indices].flatten()[:n_tokens] + + return block_indices.flatten().tolist(), slots + + def free(self, block_indices: List[int]): + # Reset mask + self.free_block_mask[block_indices] = 1 + @dataclass class FlashCausalLMBatch(Batch): @@ -32,23 +93,20 @@ class FlashCausalLMBatch(Batch): input_ids: torch.Tensor position_ids: torch.Tensor - # Indices to copy present to the correct indices is the pre-allocated past key values - past_present_indices: torch.Tensor - - # tensor of length b holding starting offset of each sequence - start_seq: torch.Tensor - # tensor of length b holding ending offset of each sequence - end_seq: torch.Tensor # tensor of length b holding starting offset of each sequence, only used in prefill start_seq_prefill: Optional[torch.Tensor] # tensor of length b holding ending offset of each sequence, only used in prefill end_seq_prefill: Optional[torch.Tensor] - # tensor of length b holding starting offset of each query sequence, only used in decode - start_seq_q: Optional[torch.Tensor] - # tensor of length b holding ending offset of each query sequence, only used in decode - end_seq_q: Optional[torch.Tensor] - # past key values, only used in decode - past_key_values: Optional[torch.Tensor] + # list of length b of list of length s_i // block_size + block_tables: List[List[int]] + # tensor of size [b, max_seqlen // block_size] holding the paged attention block tables for 
all sequences + block_tables_tensor: torch.Tensor + # CPU tensor of length b indicating the start of each sequence in slots + start_slots: torch.Tensor + # tensor of length \sum_{i=0}^{b} max_s_i holding the paged attention slots for all sequences + slots: torch.Tensor + # tensor of indices of the currently used slots, length = \sum_{i=0}^{b} s_i in prefill, length = b in decode + slot_indices: torch.Tensor max_seqlen: int # Prefill metadata tensors to efficiently compute logprobs @@ -62,6 +120,7 @@ class FlashCausalLMBatch(Batch): # Lengths of all generations present in the batch input_lengths: List[int] + input_lengths_tensor: torch.Tensor prefix_offsets: List[Optional[int]] read_offsets: List[Optional[int]] @@ -69,15 +128,16 @@ class FlashCausalLMBatch(Batch): next_token_chooser: HeterogeneousNextTokenChooser stopping_criterias: List[StoppingCriteria] - # Maximum number of tokens this batch will grow to - max_tokens: int + # Maximum number of blocks + max_blocks: int def to_pb(self) -> generate_pb2.CachedBatch: + global CACHE_MANAGER return generate_pb2.CachedBatch( id=self.batch_id, request_ids=[r.id for r in self.requests], size=len(self), - max_tokens=self.max_tokens, + max_tokens=len(self.slots), ) @classmethod @@ -88,6 +148,8 @@ class FlashCausalLMBatch(Batch): dtype: torch.dtype, device: torch.device, ) -> "FlashCausalLMBatch": + global CACHE_MANAGER + batch_inputs = [] max_truncation = 0 for r in pb.requests: @@ -99,12 +161,12 @@ class FlashCausalLMBatch(Batch): )["input_ids"] position_ids = [] - past_present_indices = [] - start_seq = [] - end_seq = [] start_seq_prefill = [] end_seq_prefill = [] - max_seqlen = 0 + block_tables = [] + start_slots = [] + slots = [] + slot_indices = [] input_lengths = [] prefix_offsets = [] @@ -126,7 +188,9 @@ class FlashCausalLMBatch(Batch): cumulative_max_length = 0 prefill_out_cumulative_length = 0 + max_seqlen = 0 max_length = 0 + max_blocks = 0 # Parse batch for i, (r, tokenized_input) in enumerate( @@ -138,7 +202,6 @@ class FlashCausalLMBatch(Batch): tokenized_input = tokenized_input[-r.truncate :] input_length = len(tokenized_input) - max_seqlen = max(max_seqlen, input_length) input_lengths.append(input_length) prefix_offsets.append(input_length - 5) @@ -153,8 +216,6 @@ class FlashCausalLMBatch(Batch): # Add cumulative lengths of all previous inputs start_seq_prefill.append(cumulative_length) end_seq_prefill.append(cumulative_length + input_length) - start_seq.append(cumulative_max_length) - end_seq.append(cumulative_max_length + input_length) next_token_chooser_parameters.append(r.parameters) @@ -164,6 +225,21 @@ class FlashCausalLMBatch(Batch): max_new_tokens = stopping_criteria.max_new_tokens stopping_criterias.append(stopping_criteria) + # Paged attention + # Remove one as the first token des not have a past + total_tokens = input_length + max_new_tokens - 1 + request_blocks, request_slots = CACHE_MANAGER.allocate(total_tokens) + block_tables.append(request_blocks) + slots.extend(request_slots) + start_slots.append(cumulative_max_length) + + request_slot_indices = torch.arange( + cumulative_max_length, + cumulative_max_length + input_length, + dtype=torch.int64, + ) + slot_indices.append(request_slot_indices) + all_prefill_logprobs = all_prefill_logprobs and r.prefill_logprobs no_prefill_logprobs = no_prefill_logprobs and not r.prefill_logprobs @@ -184,22 +260,26 @@ class FlashCausalLMBatch(Batch): prefill_cu_outlens.append(prefill_out_cumulative_length + 1) prefill_out_cumulative_length += 1 - request_past_present_indices = 
torch.arange( - cumulative_max_length, - cumulative_max_length + input_length, - dtype=torch.int64, - ) - past_present_indices.append(request_past_present_indices) - # Update - # Remove one as the first token des not have a past cumulative_length += input_length - cumulative_max_length += input_length + max_new_tokens - 1 + cumulative_max_length += total_tokens + max_seqlen = max(max_seqlen, input_length) + max_blocks = max(max_blocks, len(request_blocks)) max_length = max(max_length, input_length + max_new_tokens) next_token_chooser = HeterogeneousNextTokenChooser.from_pb( next_token_chooser_parameters, dtype, device ) + start_slots = torch.tensor(start_slots, dtype=torch.int64) + + # Padded block tables + block_tables_tensor = torch.zeros( + (len(pb.requests), max_blocks), dtype=torch.int32 + ) + for i, request_blocks in enumerate(block_tables): + block_tables_tensor[i, : len(request_blocks)] = torch.tensor(request_blocks) + + block_tables_tensor = block_tables_tensor.to(device) # Padded all_input_ids_tensor all_input_ids_tensor = np.zeros( @@ -212,34 +292,29 @@ class FlashCausalLMBatch(Batch): all_input_ids_tensor = torch.tensor( all_input_ids_tensor, dtype=torch.int64, device=device ) - start_seq = torch.tensor(start_seq, device=device, dtype=torch.int32) - end_seq = torch.tensor(end_seq, device=device, dtype=torch.int32) if len(pb.requests) > 1: input_ids = np.concatenate(all_input_ids, dtype=np.int64) position_ids = torch.cat(position_ids) - - past_present_indices = np.concatenate(past_present_indices, dtype=np.int64) - - start_seq_prefill = torch.tensor( - start_seq_prefill, device=device, dtype=torch.int32 - ) - end_seq_prefill = torch.tensor( - end_seq_prefill, device=device, dtype=torch.int32 - ) + slot_indices = torch.cat(slot_indices) else: input_ids = all_input_ids[0] position_ids = position_ids[0] + slot_indices = slot_indices[0] - past_present_indices = past_present_indices[0] - - start_seq_prefill = start_seq - end_seq_prefill = end_seq + start_seq_prefill = torch.tensor( + start_seq_prefill, device=device, dtype=torch.int32 + ) + end_seq_prefill = torch.tensor( + end_seq_prefill, device=device, dtype=torch.int32 + ) + position_ids = position_ids.to(device) + slot_indices = slot_indices.to(device) input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device) - position_ids = torch.tensor(position_ids, dtype=torch.int32, device=device) - past_present_indices = torch.tensor( - past_present_indices, device=device, dtype=torch.int64 + slots = torch.tensor(slots, dtype=torch.int32, device=device) + input_lengths_tensor = torch.tensor( + input_lengths, dtype=torch.int32, device=device ) if all_prefill_logprobs: @@ -262,30 +337,31 @@ class FlashCausalLMBatch(Batch): requests_idx_mapping=requests_idx_mapping, input_ids=input_ids, position_ids=position_ids, - past_present_indices=past_present_indices, - start_seq=start_seq, - end_seq=end_seq, start_seq_prefill=start_seq_prefill, end_seq_prefill=end_seq_prefill, - start_seq_q=None, - end_seq_q=None, + block_tables=block_tables, + block_tables_tensor=block_tables_tensor, + start_slots=start_slots, + slots=slots, + slot_indices=slot_indices, max_seqlen=max_seqlen, prefill_head_indices=prefill_head_indices, prefill_next_token_indices=prefill_next_token_indices, prefill_cu_outlens=prefill_cu_outlens, - past_key_values=None, input_lengths=input_lengths, + input_lengths_tensor=input_lengths_tensor, prefix_offsets=prefix_offsets, read_offsets=read_offsets, all_input_ids=all_input_ids, all_input_ids_tensor=all_input_ids_tensor, 
next_token_chooser=next_token_chooser, stopping_criterias=stopping_criterias, - max_tokens=cumulative_max_length, + max_blocks=max_blocks, ) @tracer.start_as_current_span("filter") def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch": + global CACHE_MANAGER if len(request_ids) == 0: raise ValueError("Batch must have at least one request") # We assume that if len(requests) == len(self) then the requests are the same @@ -294,28 +370,24 @@ class FlashCausalLMBatch(Batch): device = self.input_ids.device - # Cumulative length - cumulative_max_length = 0 - # New values after filtering requests_idx_mapping = {} # Used to index into tensors indices = [] - # past indices to keep - past_indices = torch.zeros( - self.past_key_values.shape[0], dtype=torch.bool, device=device + # slots to keep after filtering + slot_filtering_indices = torch.zeros( + self.slots.shape[0], dtype=torch.bool, device=device ) # Create on CPU to only move to GPU once instead of at every copy - start_seq = torch.empty(len(request_ids), dtype=torch.int32) - end_seq = torch.empty(len(request_ids), dtype=torch.int32) - start_seq_q = self.start_seq_q[: len(request_ids)] - end_seq_q = self.end_seq_q[: len(request_ids)] + slot_indices = torch.empty(len(request_ids), dtype=torch.int64) max_seqlen = 0 requests = [] + start_slots = [] + block_tables = [] all_input_ids = [] input_lengths = [] @@ -324,6 +396,10 @@ class FlashCausalLMBatch(Batch): stopping_criterias = [] + max_blocks = 0 + # Cumulative length + cumulative_max_length = 0 + for i, request_id in enumerate(request_ids): idx = self.requests_idx_mapping[request_id] indices.append(idx) @@ -348,28 +424,45 @@ class FlashCausalLMBatch(Batch): stopping_criteria.max_new_tokens - stopping_criteria.current_tokens ) + request_block_table = self.block_tables[idx] + block_tables.append(request_block_table) + start_slots.append(cumulative_max_length) + # Copy to tensor (CPU) - start_seq[i] = cumulative_max_length - end_seq[i] = cumulative_max_length + request_input_length + slot_indices[i] = cumulative_max_length + request_input_length - 1 # Set slice - past_indices[ - self.start_seq[idx] : self.end_seq[idx] + remaining_tokens - 1 + slot_filtering_indices[ + self.start_slots[idx] : self.start_slots[idx] + + request_input_length + + remaining_tokens + - 1 ] = True cumulative_max_length += request_input_length + remaining_tokens - 1 + max_blocks = max(max_blocks, len(request_block_table)) + + # Iterate on all requests + for i, r in enumerate(self.requests): + # Filter requests that are not part of the new batch + if r.id not in requests_idx_mapping.keys(): + # Free blocks + CACHE_MANAGER.free(self.block_tables[i]) + # Index into tensors input_ids = self.input_ids[indices] position_ids = self.position_ids[indices] all_input_ids_tensor = self.all_input_ids_tensor[indices] + block_tables_tensor = self.block_tables_tensor[indices] + input_lengths_tensor = self.input_lengths_tensor[indices] + slots = self.slots[slot_filtering_indices] next_token_chooser = self.next_token_chooser.filter(indices) - past_key_values = self.past_key_values[past_indices] + + start_slots = torch.tensor(start_slots, dtype=torch.int64) # Move to GPU now that we have the whole tensor - start_seq = start_seq.to(device) - end_seq = end_seq.to(device) - past_present_indices = end_seq - 1 + slot_indices = slot_indices.to(device) return FlashCausalLMBatch( batch_id=self.batch_id, @@ -377,51 +470,74 @@ class FlashCausalLMBatch(Batch): requests_idx_mapping=requests_idx_mapping, input_ids=input_ids, 
position_ids=position_ids, - past_present_indices=past_present_indices, - start_seq=start_seq, - end_seq=end_seq, start_seq_prefill=None, end_seq_prefill=None, - start_seq_q=start_seq_q, - end_seq_q=end_seq_q, + block_tables=block_tables, + block_tables_tensor=block_tables_tensor, + start_slots=start_slots, + slots=slots, + slot_indices=slot_indices, max_seqlen=max_seqlen, prefill_head_indices=None, prefill_next_token_indices=None, prefill_cu_outlens=None, - past_key_values=past_key_values, input_lengths=input_lengths, + input_lengths_tensor=input_lengths_tensor, prefix_offsets=prefix_offsets, read_offsets=read_offsets, all_input_ids=all_input_ids, all_input_ids_tensor=all_input_ids_tensor, next_token_chooser=next_token_chooser, stopping_criterias=stopping_criterias, - max_tokens=cumulative_max_length, + max_blocks=max_blocks, ) @classmethod @tracer.start_as_current_span("concatenate") def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch": + global CACHE_MANAGER # Batch attributes requests = [] requests_idx_mapping = {} - total_batch_size = sum([len(b) for b in batches]) - - dtype = batches[0].past_key_values.dtype - device = batches[0].input_ids.device + total_batch_size = 0 + total_slots = 0 + max_blocks = 0 + max_length = 0 + max_seqlen = 0 + for b in batches: + total_batch_size += len(b) + total_slots += len(b.slots) + max_blocks = max(max_blocks, b.max_blocks) + max_seqlen = max(max_seqlen, b.max_seqlen) + max_length = max( + max_length, + max( + input_length + + stopping_criteria.max_new_tokens + - stopping_criteria.current_tokens + for input_length, stopping_criteria in zip( + b.input_lengths, b.stopping_criterias + ) + ), + ) input_ids = batches[0].input_ids.new_empty(total_batch_size) position_ids = batches[0].position_ids.new_empty(total_batch_size) - start_seq = batches[0].start_seq.new_empty(total_batch_size) - end_seq = batches[0].end_seq.new_empty(total_batch_size) - start_seq_q = torch.arange( - 0, total_batch_size, device=device, dtype=torch.int32 + slots = batches[0].slots.new_empty(total_slots) + slot_indices = batches[0].slot_indices.new_empty(total_batch_size) + input_lengths_tensor = batches[0].input_lengths_tensor.new_empty( + total_batch_size + ) + block_tables_tensor = batches[0].block_tables_tensor.new_zeros( + (total_batch_size, max_blocks) + ) + all_input_ids_tensor = batches[0].all_input_ids_tensor.new_zeros( + (total_batch_size, max_length) ) - end_seq_q = start_seq_q + 1 - max_seqlen = 0 - past_key_values = [] + start_slots = [] + block_tables = [] all_input_ids = [] input_lengths = [] @@ -433,8 +549,7 @@ class FlashCausalLMBatch(Batch): # Cumulative length cumulative_batch_size = 0 - max_tokens = 0 - max_length = 0 + cumulative_slots = 0 for i, batch in enumerate(batches): requests.extend(batch.requests) @@ -448,16 +563,27 @@ class FlashCausalLMBatch(Batch): start_index = cumulative_batch_size end_index = cumulative_batch_size + len(batch) + slots_start_index = cumulative_slots + slots_end_index = cumulative_slots + len(batch.slots) # Copy tensors (GPU) input_ids[start_index:end_index] = batch.input_ids position_ids[start_index:end_index] = batch.position_ids + slot_indices[start_index:end_index] = batch.slot_indices + cumulative_slots + input_lengths_tensor[start_index:end_index] = batch.input_lengths_tensor + slots[slots_start_index:slots_end_index] = batch.slots - start_seq[start_index:end_index] = batch.start_seq + max_tokens - end_seq[start_index:end_index] = batch.end_seq + max_tokens + all_input_ids_tensor[ + 
start_index:end_index, : batch.all_input_ids_tensor.shape[1] + ] = batch.all_input_ids_tensor[:, :max_length] - max_seqlen = max(max_seqlen, batch.max_seqlen) + block_tables_tensor[ + start_index:end_index, : batch.block_tables_tensor.shape[1] + ] = batch.block_tables_tensor[:, :max_blocks] + start_slots.append(batch.start_slots + cumulative_slots) + + block_tables.extend(batch.block_tables) all_input_ids.extend(batch.all_input_ids) input_lengths.extend(batch.input_lengths) @@ -466,43 +592,17 @@ class FlashCausalLMBatch(Batch): next_token_chooser_parameters.extend([r.parameters for r in batch.requests]) stopping_criterias.extend(batch.stopping_criterias) - past_key_values.append(batch.past_key_values) # Update cumulative_batch_size += len(batch) - max_tokens += batch.max_tokens - max_length = max( - max_length, - max( - input_length - + stopping_criteria.max_new_tokens - - stopping_criteria.current_tokens - for input_length, stopping_criteria in zip( - batch.input_lengths, batch.stopping_criterias - ) - ), - ) + cumulative_slots += len(batch.slots) - past_key_values = torch.cat(past_key_values, dim=0) - past_present_indices = end_seq - 1 - - all_input_ids_tensor = torch.zeros( - (total_batch_size, max_length), dtype=torch.int64, device=device - ) - - cumulative_batch_size = 0 - for i, batch in enumerate(batches): - start_index = cumulative_batch_size - end_index = cumulative_batch_size + len(batch) - - all_input_ids_tensor[ - start_index:end_index, : batch.all_input_ids_tensor.shape[1] - ] = batch.all_input_ids_tensor[:, :max_length] - - cumulative_batch_size += len(batch) + start_slots = torch.concat(start_slots) next_token_chooser = HeterogeneousNextTokenChooser.from_pb( - next_token_chooser_parameters, dtype=dtype, device=device + next_token_chooser_parameters, + dtype=batches[0].next_token_chooser.dtype, + device=batches[0].next_token_chooser.device, ) return FlashCausalLMBatch( @@ -511,28 +611,33 @@ class FlashCausalLMBatch(Batch): requests_idx_mapping=requests_idx_mapping, input_ids=input_ids, position_ids=position_ids, - past_present_indices=past_present_indices, - start_seq=start_seq, - end_seq=end_seq, start_seq_prefill=None, end_seq_prefill=None, - start_seq_q=start_seq_q, - end_seq_q=end_seq_q, + block_tables=block_tables, + block_tables_tensor=block_tables_tensor, + start_slots=start_slots, + slots=slots, + slot_indices=slot_indices, max_seqlen=max_seqlen, prefill_head_indices=None, prefill_next_token_indices=None, prefill_cu_outlens=None, - past_key_values=past_key_values, input_lengths=input_lengths, + input_lengths_tensor=input_lengths_tensor, prefix_offsets=prefix_offsets, read_offsets=read_offsets, all_input_ids=all_input_ids, all_input_ids_tensor=all_input_ids_tensor, next_token_chooser=next_token_chooser, stopping_criterias=stopping_criterias, - max_tokens=max_tokens, + max_blocks=max_blocks, ) + def cleanup(self): + global CACHE_MANAGER + # Free blocks + CACHE_MANAGER.free(list(itertools.chain.from_iterable(self.block_tables))) + def __len__(self): return len(self.requests) @@ -540,32 +645,24 @@ class FlashCausalLMBatch(Batch): class FlashCausalLM(Model): def __init__( self, - model_cls: Type[PreTrainedModel], - model_id: str, - revision: Optional[str] = None, - quantize: Optional[str] = None, - trust_remote_code: bool = False, + model: torch.nn.Module, + tokenizer: PreTrainedTokenizerBase, + num_layers: int, + num_heads: int, + head_size: int, + dtype: torch.dtype, + device: torch.device, + rank: int = 0, + world_size: int = 1, ): - if torch.cuda.is_available(): - 
device = torch.device("cuda") - dtype = torch.float16 - else: - raise NotImplementedError("FlashCausalLM is only available on GPU") + self.num_heads = num_heads + self.head_size = head_size - tokenizer = AutoTokenizer.from_pretrained( - model_id, - revision=revision, - padding_side="left", - truncation_side="left", - trust_remote_code=trust_remote_code, + global CACHE_MANAGER + torch.cuda.set_per_process_memory_fraction(1.0) + CACHE_MANAGER = CacheManager( + 1000, num_layers, num_heads, head_size, dtype, device ) - model = model_cls.from_pretrained( - model_id, - revision=revision, - torch_dtype=dtype, - load_in_8bit=quantize == "bitsandbytes", - trust_remote_code=trust_remote_code, - ).to(device) super(FlashCausalLM, self).__init__( model=model, @@ -573,6 +670,8 @@ class FlashCausalLM(Model): requires_padding=False, dtype=dtype, device=device, + rank=rank, + world_size=world_size, ) @property @@ -588,28 +687,27 @@ class FlashCausalLM(Model): self, input_ids: torch.Tensor, position_ids: torch.Tensor, - start_seq: torch.Tensor, - end_seq: torch.Tensor, - start_seq_q: Optional[torch.Tensor], - end_seq_q: Optional[torch.Tensor], + start_seq_prefill: Optional[torch.Tensor], + end_seq_prefill: Optional[torch.Tensor], + block_tables: torch.Tensor, + slots: torch.Tensor, + input_lengths: torch.Tensor, max_s: int, - past_present_indices: torch.Tensor, - past_key_values: Optional = None, - pre_allocate_past_size: Optional[int] = None, lm_head_indices: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: + global CACHE_MANAGER + # Model Forward return self.model.forward( input_ids=input_ids, position_ids=position_ids, - start_seq=start_seq, - end_seq=end_seq, - start_seq_q=start_seq_q, - end_seq_q=end_seq_q, + start_seq_prefill=start_seq_prefill, + end_seq_prefill=end_seq_prefill, + kv_cache=CACHE_MANAGER.kv_cache, + block_tables=block_tables, + slots=slots, + input_lengths=input_lengths, max_s=max_s, - past_present_indices=past_present_indices, - past_key_values=past_key_values, - pre_allocate_past_size=pre_allocate_past_size, lm_head_indices=lm_head_indices, ) @@ -617,31 +715,18 @@ class FlashCausalLM(Model): def generate_token( self, batch: FlashCausalLMBatch ) -> Tuple[List[Generation], Optional[FlashCausalLMBatch]]: - prefill = batch.past_key_values is None + prefill = batch.start_seq_prefill is not None prefill_logprobs = batch.prefill_next_token_indices is not None - if prefill: - # Ask to pre-allocate kv to its max size - # == Sum over batch size (number of tokens + max_new_tokens) - batch size - pre_allocate_past_size = batch.max_tokens - start_seq = batch.start_seq_prefill - end_seq = batch.end_seq_prefill - else: - pre_allocate_past_size = None - start_seq = batch.start_seq - end_seq = batch.end_seq - - out, present = self.forward( + out = self.forward( batch.input_ids, batch.position_ids, - start_seq, - end_seq, - batch.start_seq_q, - batch.end_seq_q, + batch.start_seq_prefill, + batch.end_seq_prefill, + batch.block_tables_tensor, + batch.slots[batch.slot_indices], + batch.input_lengths_tensor, batch.max_seqlen, - batch.past_present_indices, - batch.past_key_values, - pre_allocate_past_size, batch.prefill_head_indices, ) @@ -662,12 +747,8 @@ class FlashCausalLM(Model): # When batch == 1, we will just use the batch.input_ids values directly prefill_tokens_indices = batch.input_ids.new_zeros(len(out)) - # Create batch.start_seq_q and batch.end_seq_q for decode - batch.start_seq_q = torch.arange( - 0, len(batch), device=self.device, dtype=torch.int32 - ) - batch.end_seq_q 
= batch.start_seq_q + 1 next_position_ids = batch.position_ids.new_empty(len(batch)) + batch.slot_indices = batch.slot_indices[batch.end_seq_prefill - 1] # We do not need start_seq_prefill and end_seq_prefill anymore batch.start_seq_prefill = None batch.end_seq_prefill = None @@ -731,8 +812,8 @@ class FlashCausalLM(Model): # Set values in batch batch.input_ids = next_input_ids batch.position_ids = next_position_ids + 1 - batch.past_present_indices = batch.end_seq - batch.end_seq = batch.end_seq + 1 + batch.input_lengths_tensor += 1 + batch.slot_indices += 1 if prefill and prefill_logprobs: # Get prefill logprobs @@ -755,7 +836,6 @@ class FlashCausalLM(Model): batch.read_offsets, batch.stopping_criterias, batch.all_input_ids, - batch.all_input_ids_tensor, batch.next_token_chooser.do_sample, batch.next_token_chooser.seeds, next_token_ids, @@ -770,7 +850,6 @@ class FlashCausalLM(Model): read_offset, stopping_criteria, all_input_ids, - all_input_ids_tensor, do_sample, seed, next_token_id, @@ -845,19 +924,20 @@ class FlashCausalLM(Model): generations.append(generation) - new_input_length = input_length + 1 - # Update values - batch.input_lengths[i] = new_input_length + batch.input_lengths[i] = input_length + 1 batch.prefix_offsets[i] = prefix_offset batch.read_offsets[i] = read_offset batch.all_input_ids[i] = all_input_ids + if stopped: + batch.cleanup() + # No need to return a batch if we know that all requests stopped + return generations, None + batch.prefill_cu_outlens = None batch.prefill_head_indices = None batch.prefill_next_token_indices = None batch.max_seqlen = batch.max_seqlen + 1 - batch.past_key_values = present - # No need to return a batch if we know that all requests stopped - return generations, batch if not stopped else None + return generations, batch diff --git a/server/text_generation_server/models/flash_llama.py b/server/text_generation_server/models/flash_llama.py index a80d58cb..383b8f43 100644 --- a/server/text_generation_server/models/flash_llama.py +++ b/server/text_generation_server/models/flash_llama.py @@ -64,10 +64,12 @@ class FlashLlama(FlashCausalLM): model = FlashLlamaForCausalLM(config, weights) torch.distributed.barrier(group=self.process_group) - super(FlashCausalLM, self).__init__( + super(FlashLlama, self).__init__( model=model, tokenizer=tokenizer, - requires_padding=False, + num_layers=len(model.model.layers), + num_heads=model.model.num_heads, + head_size=model.model.head_size, dtype=dtype, device=device, rank=rank, diff --git a/server/text_generation_server/models/flash_santacoder.py b/server/text_generation_server/models/flash_santacoder.py index a71c0061..f4363e19 100644 --- a/server/text_generation_server/models/flash_santacoder.py +++ b/server/text_generation_server/models/flash_santacoder.py @@ -52,8 +52,11 @@ class FlashSantacoderSharded(FlashCausalLM): torch.distributed.barrier(group=self.process_group) filenames = weight_files(model_id, revision=revision, extension=".safetensors") weights = Weights( - filenames, device=device, dtype=dtype, process_group=self.process_group, - aliases = {"transformer.wte.weight": ["lm_head.weight"]} + filenames, + device=device, + dtype=dtype, + process_group=self.process_group, + aliases={"transformer.wte.weight": ["lm_head.weight"]}, ) model = FlashSantacoderForCausalLM(config, weights) diff --git a/server/text_generation_server/models/types.py b/server/text_generation_server/models/types.py index 28ca8147..bd92022e 100644 --- a/server/text_generation_server/models/types.py +++ 
b/server/text_generation_server/models/types.py @@ -35,6 +35,9 @@ class Batch(ABC): def concatenate(cls, batches: List["Batch"]) -> "Batch": raise NotImplementedError + def cleanup(self): + pass + @abstractmethod def __len__(self): raise NotImplementedError diff --git a/server/text_generation_server/utils/tokens.py b/server/text_generation_server/utils/tokens.py index e6e512bc..b83af591 100644 --- a/server/text_generation_server/utils/tokens.py +++ b/server/text_generation_server/utils/tokens.py @@ -216,6 +216,8 @@ class HeterogeneousNextTokenChooser: self.seeds = seeds self.do_sample = do_sample + self.dtype = dtype + self.device = device def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor): if self.watermark_processor is not None: diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py index 9d371834..83d9df68 100644 --- a/server/text_generation_server/utils/weights.py +++ b/server/text_generation_server/utils/weights.py @@ -5,7 +5,14 @@ import torch class Weights: - def __init__(self, filenames: List[Path], device, dtype, process_group, aliases: Optional[Dict[str, List[str]]]=None): + def __init__( + self, + filenames: List[Path], + device, + dtype, + process_group, + aliases: Optional[Dict[str, List[str]]] = None, + ): routing = {} for filename in filenames: with safe_open(filename, framework="pytorch") as f: @@ -43,7 +50,7 @@ class Weights: return str(filename), tensor_name def _get_slice(self, tensor_name: str): - filename, tensor_name= self.get_filename(tensor_name) + filename, tensor_name = self.get_filename(tensor_name) f = self._get_handle(filename) slice_ = f.get_slice(tensor_name) return slice_ @@ -94,12 +101,20 @@ class Weights: def get_multi_weights_col(self, prefixes: List[str], quantize: str, dim: int): if quantize == "gptq": try: - qweight = torch.cat([self.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1) + qweight = torch.cat( + [self.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1 + ) except RuntimeError: - raise RuntimeError("Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`") + raise RuntimeError( + "Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`" + ) - qzeros = torch.cat([self.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1) - scales = torch.cat([self.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1) + qzeros = torch.cat( + [self.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1 + ) + scales = torch.cat( + [self.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1 + ) w = [self.get_tensor(f"{p}.g_idx") for p in prefixes] for w2 in w[1:]: torch.testing.assert_close(w2, w[0]) @@ -118,7 +133,9 @@ class Weights: try: qweight = self.get_sharded(f"{prefix}.qweight", dim=0) except RuntimeError: - raise RuntimeError("Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`") + raise RuntimeError( + "Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`" + ) qzeros = self.get_tensor(f"{prefix}.qzeros") scales = self.get_tensor(f"{prefix}.scales") g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)