diff --git a/.github/workflows/build_documentation.yml b/.github/workflows/build_documentation.yml index a0f1d6f1..4d0b19a3 100644 --- a/.github/workflows/build_documentation.yml +++ b/.github/workflows/build_documentation.yml @@ -17,5 +17,4 @@ jobs: package: text-generation-inference additional_args: --not_python_module secrets: - token: ${{ secrets.HUGGINGFACE_PUSH }} - hf_token: ${{ secrets.HF_DOC_BUILD_PUSH }} \ No newline at end of file + hf_token: ${{ secrets.HF_DOC_BUILD_PUSH }} diff --git a/Cargo.lock b/Cargo.lock index 8fa7b726..b1f7279a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -743,18 +743,6 @@ version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "28a80e3145d8ad11ba0995949bbcf48b9df2be62772b3d351ef017dff6ecb853" -[[package]] -name = "flume" -version = "0.11.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "55ac459de2512911e4b674ce33cf20befaba382d05b62b008afc1c8b57cbf181" -dependencies = [ - "futures-core", - "futures-sink", - "nanorand", - "spin 0.9.8", -] - [[package]] name = "fnv" version = "1.0.7" @@ -900,10 +888,8 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "be4136b2a15dd319360be1c07d9933517ccf0be8f16bf62a3bee4f0d618df427" dependencies = [ "cfg-if", - "js-sys", "libc", "wasi", - "wasm-bindgen", ] [[package]] @@ -1508,15 +1494,6 @@ dependencies = [ "tracing", ] -[[package]] -name = "nanorand" -version = "0.7.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6a51313c5820b0b02bd422f4b44776fbf47961755c74ce64afc73bfad10226c3" -dependencies = [ - "getrandom", -] - [[package]] name = "native-tls" version = "0.2.11" @@ -2313,7 +2290,7 @@ dependencies = [ "cc", "libc", "once_cell", - "spin 0.5.2", + "spin", "untrusted", "web-sys", "winapi", @@ -2678,15 +2655,6 @@ version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6e63cff320ae2c57904679ba7cb63280a3dc4613885beafb148ee7bf9aa9042d" -[[package]] -name = "spin" -version = "0.9.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67" -dependencies = [ - "lock_api", -] - [[package]] name = "spm_precompiled" version = "0.1.4" @@ -2808,7 +2776,7 @@ dependencies = [ [[package]] name = "text-generation-benchmark" -version = "1.1.0" +version = "1.1.1" dependencies = [ "average", "clap", @@ -2829,7 +2797,7 @@ dependencies = [ [[package]] name = "text-generation-client" -version = "1.1.0" +version = "1.1.1" dependencies = [ "futures", "grpc-metadata", @@ -2845,7 +2813,7 @@ dependencies = [ [[package]] name = "text-generation-launcher" -version = "1.1.0" +version = "1.1.1" dependencies = [ "clap", "ctrlc", @@ -2861,13 +2829,12 @@ dependencies = [ [[package]] name = "text-generation-router" -version = "1.1.0" +version = "1.1.1" dependencies = [ "async-stream", "axum", "axum-tracing-opentelemetry", "clap", - "flume", "futures", "hf-hub 0.3.1", "init-tracing-opentelemetry", @@ -2885,6 +2852,7 @@ dependencies = [ "thiserror", "tokenizers", "tokio", + "tokio-stream", "tower-http", "tracing", "tracing-opentelemetry", diff --git a/docs/source/basic_tutorials/preparing_model.md b/docs/source/basic_tutorials/preparing_model.md index 97c9bbe0..56124a3b 100644 --- a/docs/source/basic_tutorials/preparing_model.md +++ b/docs/source/basic_tutorials/preparing_model.md @@ -4,7 +4,7 @@ Text Generation Inference improves the model in several aspects. 
## Quantization -TGI supports [bits-and-bytes](https://github.com/TimDettmers/bitsandbytes#bitsandbytes), [GPT-Q](https://arxiv.org/abs/2210.17323) and [AWQ](https://arxiv.org/abs/2306.00978) quantization. To speed up inference with quantization, simply set `quantize` flag to `bitsandbytes`, `gptq` or `awq` depending on the quantization technique you wish to use. When using GPT-Q quantization, you need to point to one of the models [here](https://huggingface.co/models?search=gptq) when using AWQ quantization, you need to point to one of the models [here](https://huggingface.co/models?search=awq). To get more information about quantization, please refer to [quantization guide](./conceptual/quantization.md) +TGI supports [bits-and-bytes](https://github.com/TimDettmers/bitsandbytes#bitsandbytes), [GPT-Q](https://arxiv.org/abs/2210.17323) and [AWQ](https://arxiv.org/abs/2306.00978) quantization. To speed up inference with quantization, simply set the `quantize` flag to `bitsandbytes`, `gptq` or `awq`, depending on the quantization technique you wish to use. When using GPT-Q quantization, you need to point to one of the models [here](https://huggingface.co/models?search=gptq). When using AWQ quantization, you need to point to one of the models [here](https://huggingface.co/models?search=awq). To get more information about quantization, please refer to the [quantization guide](./../conceptual/quantization.md). ## RoPE Scaling diff --git a/router/Cargo.toml b/router/Cargo.toml index 87b5a8d3..55af635a 100644 --- a/router/Cargo.toml +++ b/router/Cargo.toml @@ -20,7 +20,6 @@ axum = { version = "0.6.20", features = ["json"] } axum-tracing-opentelemetry = "0.14.1" text-generation-client = { path = "client" } clap = { version = "4.4.5", features = ["derive", "env"] } -flume = "0.11.0" futures = "0.3.28" metrics = "0.21.1" metrics-exporter-prometheus = { version = "0.12.1", features = [] } @@ -34,6 +33,7 @@ serde_json = "1.0.107" thiserror = "1.0.48" tokenizers = { version = "0.14.0", features = ["http"] } tokio = { version = "1.32.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] } +tokio-stream = "0.1.14" tower-http = { version = "0.4.4", features = ["cors"] } tracing = "0.1.37" tracing-opentelemetry = "0.21.0" diff --git a/router/client/src/client.rs b/router/client/src/client.rs index d427d3a4..341e70fd 100644 --- a/router/client/src/client.rs +++ b/router/client/src/client.rs @@ -103,17 +103,18 @@ impl Client { &mut self, max_input_length: u32, max_prefill_tokens: u32, + max_total_tokens: u32, ) -> Result<Option<u32>> { let mut n_tokens = 0; let mut requests = Vec::new(); - // Create requests while n_tokens < max_prefill_tokens { + let truncate = min(max_input_length, max_prefill_tokens - n_tokens); requests.push(Request { id: 0, // We truncate the input on the server side to be sure that it has the correct size inputs: "_test ".to_string().repeat(max_input_length as usize), - truncate: min(max_input_length, max_prefill_tokens - n_tokens), + truncate, // Set sampling parameters to also take these ops into account in the max memory parameters: Some(NextTokenChooserParameters { temperature: 0.9, @@ -126,9 +127,9 @@ impl Client { watermark: true, }), stopping_parameters: Some(StoppingCriteriaParameters { - max_new_tokens: 2, + max_new_tokens: max_total_tokens - truncate, stop_sequences: vec![], - ignore_eos_token: false, + ignore_eos_token: true, }), prefill_logprobs: true, top_n_tokens: 20, diff --git a/router/client/src/sharded_client.rs b/router/client/src/sharded_client.rs index 
112b0035..b4bdcd42 100644 --- a/router/client/src/sharded_client.rs +++ b/router/client/src/sharded_client.rs @@ -95,11 +95,14 @@ impl ShardedClient { &mut self, max_input_length: u32, max_prefill_tokens: u32, + max_total_tokens: u32, ) -> Result<Option<u32>> { let futures: Vec<_> = self .clients .iter_mut() - .map(|client| Box::pin(client.warmup(max_input_length, max_prefill_tokens))) + .map(|client| { + Box::pin(client.warmup(max_input_length, max_prefill_tokens, max_total_tokens)) + }) .collect(); // Take the minimum value let results = join_all(futures) diff --git a/router/src/infer.rs b/router/src/infer.rs index 787ccfcf..cc34c466 100644 --- a/router/src/infer.rs +++ b/router/src/infer.rs @@ -2,22 +2,21 @@ use crate::validation::{Validation, ValidationError}; use crate::{Entry, Queue, Token}; use crate::{GenerateRequest, PrefillToken}; -use flume::r#async::RecvStream; -use flume::SendTimeoutError; use futures::future::try_join_all; -use futures::stream::StreamExt; use nohash_hasher::IntMap; use std::sync::{ atomic::{AtomicBool, Ordering}, Arc, }; -use std::time::Duration; use text_generation_client::{ Batch, CachedBatch, ClientError, GeneratedText, Generation, PrefillTokens, ShardedClient, }; use thiserror::Error; -use tokio::sync::{Notify, OwnedSemaphorePermit, Semaphore, TryAcquireError}; +use tokio::sync::mpsc::error::SendError; +use tokio::sync::{mpsc, Notify, OwnedSemaphorePermit, Semaphore, TryAcquireError}; use tokio::time::Instant; +use tokio_stream::wrappers::UnboundedReceiverStream; +use tokio_stream::StreamExt; use tracing::{info_span, instrument, Instrument, Span}; /// Inference struct @@ -90,7 +89,7 @@ impl Infer { ) -> Result< ( OwnedSemaphorePermit, - RecvStream<'static, Result<InferStreamResponse, InferError>>, + UnboundedReceiverStream<Result<InferStreamResponse, InferError>>, ), InferError, > { @@ -113,7 +112,7 @@ impl Infer { })?; // MPSC channel to communicate with the background batching task - let (response_tx, response_rx) = flume::unbounded(); + let (response_tx, response_rx) = mpsc::unbounded_channel(); // Append the request to the queue self.queue.append(Entry { @@ -130,7 +129,7 @@ impl Infer { self.shared.batching_task.notify_one(); // Return stream - Ok((permit, response_rx.into_stream())) + Ok((permit, UnboundedReceiverStream::new(response_rx))) } /// Add a new request to the queue and return a InferResponse @@ -493,10 +492,7 @@ fn filter_send_generations(generations: Vec<Generation>, entries: &mut IntMap "dropped"); err }).unwrap_or(true); @@ -510,9 +506,10 @@ fn send_responses( -) -> Result<bool, SendTimeoutError<Result<InferStreamResponse, InferError>>> { +) -> Result<bool, SendError<Result<InferStreamResponse, InferError>>> { // Return directly if the channel is disconnected - if entry.response_tx.is_disconnected() { + if entry.response_tx.is_closed() { + metrics::increment_counter!("tgi_request_failure", "err" => "dropped"); return Ok(true); } @@ -520,10 +517,9 @@ fn send_responses( if let Some(prefill_tokens) = generation.prefill_tokens { // Send message - entry.response_tx.send_timeout( - Ok(InferStreamResponse::Prefill(prefill_tokens)), - Duration::from_millis(10), - )?; + entry + .response_tx + .send(Ok(InferStreamResponse::Prefill(prefill_tokens)))?; } // Create last Token @@ -558,22 +554,18 @@ fn send_responses( // Generation has ended stopped = true; // Send message - entry.response_tx.send_timeout( - Ok(InferStreamResponse::End { token, top_tokens, generated_text, queued: entry.queue_time, start: entry.batch_time.unwrap(), - }), - Duration::from_millis(10), - )?; + entry.response_tx.send(Ok(InferStreamResponse::End { + token, + top_tokens, + generated_text, + queued: entry.queue_time, + start: 
entry.batch_time.unwrap(), + }))?; } else { // Send message - entry.response_tx.send_timeout( - Ok(InferStreamResponse::Intermediate { token, top_tokens }), - Duration::from_millis(10), - )?; + entry + .response_tx + .send(Ok(InferStreamResponse::Intermediate { token, top_tokens }))?; } Ok(stopped) } @@ -591,7 +583,7 @@ fn send_errors(error: ClientError, entries: &mut IntMap<u64, Entry>) { // unwrap_or is valid here as we don't care if the receiver is gone. entry .response_tx - .send_timeout(Err(err), Duration::from_millis(10)) + .send(Err(err)) .unwrap_or(()); }); } diff --git a/router/src/main.rs b/router/src/main.rs index f3028674..d90632ef 100644 --- a/router/src/main.rs +++ b/router/src/main.rs @@ -212,7 +212,7 @@ fn main() -> Result<(), RouterError> { // Warmup model tracing::info!("Warming up model"); let max_supported_batch_total_tokens = match sharded_client - .warmup(max_input_length as u32, max_batch_prefill_tokens) + .warmup(max_input_length as u32, max_batch_prefill_tokens, max_total_tokens as u32) .await .map_err(RouterError::Warmup)? { diff --git a/router/src/queue.rs b/router/src/queue.rs index 1ab9eb11..bbb8db0e 100644 --- a/router/src/queue.rs +++ b/router/src/queue.rs @@ -5,7 +5,7 @@ use nohash_hasher::{BuildNoHashHasher, IntMap}; use std::cmp::min; use std::collections::VecDeque; use text_generation_client::{Batch, Request}; -use tokio::sync::oneshot; +use tokio::sync::{mpsc, oneshot}; use tokio::time::Instant; use tracing::{info_span, instrument, Span}; @@ -15,7 +15,7 @@ pub(crate) struct Entry { /// Request pub request: ValidGenerateRequest, /// Response sender to communicate between the Infer struct and the batching_task - pub response_tx: flume::Sender<Result<InferStreamResponse, InferError>>, + pub response_tx: mpsc::UnboundedSender<Result<InferStreamResponse, InferError>>, /// Span that will live as long as entry pub span: Span, /// Temporary span used as a guard when logging inference, wait times... 
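The router-side hunks above are a wholesale migration from `flume` to `tokio::sync::mpsc` unbounded channels: `flume::unbounded()` becomes `mpsc::unbounded_channel()`, `is_disconnected()` becomes `is_closed()`, the `send_timeout(..., Duration::from_millis(10))` calls collapse to plain `send` (sending on an unbounded channel never blocks), and the receiver is exposed as a `Stream` through `tokio_stream::wrappers::UnboundedReceiverStream` instead of flume's `RecvStream`. A minimal, self-contained sketch of the pattern the router adopts (not TGI code; assumes `tokio` with the `full` feature and `tokio-stream` as dependencies):

```rust
use tokio::sync::mpsc;
use tokio_stream::wrappers::UnboundedReceiverStream;
use tokio_stream::StreamExt;

#[tokio::main]
async fn main() {
    // One unbounded MPSC channel per request, as in Infer::generate_stream.
    let (response_tx, response_rx) = mpsc::unbounded_channel::<Result<String, String>>();

    // `send` on an unbounded channel never blocks, so the old
    // `send_timeout(msg, Duration::from_millis(10))` becomes a plain `send`;
    // it only errors once the receiving half has been dropped.
    response_tx.send(Ok("token".to_string())).unwrap();

    // `is_closed()` replaces flume's `is_disconnected()` for detecting that
    // the client went away before spending more work on its entry.
    assert!(!response_tx.is_closed());
    drop(response_tx); // close the channel so the stream below terminates

    // Wrap the receiver so callers can consume it as a `Stream`,
    // replacing flume's `receiver.into_stream()`.
    let mut stream = UnboundedReceiverStream::new(response_rx);
    while let Some(response) = stream.next().await {
        println!("{response:?}");
    }
}
```

Because the channel is unbounded, a failed `send` now unambiguously means the client dropped the request, which is the signal the batching task uses to stop generating; backpressure is handled upstream by the `OwnedSemaphorePermit` returned alongside the stream.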
@@ -30,13 +30,13 @@ pub(crate) struct Entry { #[derive(Debug, Clone)] pub(crate) struct Queue { /// Channel to communicate with the background queue task - queue_sender: flume::Sender<QueueCommand>, + queue_sender: mpsc::UnboundedSender<QueueCommand>, } impl Queue { pub(crate) fn new(requires_padding: bool, block_size: u32, window_size: Option<u32>) -> Self { // Create channel - let (queue_sender, queue_receiver) = flume::unbounded(); + let (queue_sender, queue_receiver) = mpsc::unbounded_channel(); // Launch background queue task tokio::spawn(queue_task( @@ -91,11 +91,11 @@ async fn queue_task( requires_padding: bool, block_size: u32, window_size: Option<u32>, - receiver: flume::Receiver<QueueCommand>, + mut receiver: mpsc::UnboundedReceiver<QueueCommand>, ) { let mut state = State::new(requires_padding, block_size, window_size); - while let Ok(cmd) = receiver.recv_async().await { + while let Some(cmd) = receiver.recv().await { match cmd { QueueCommand::Append(entry, span) => { span.in_scope(|| state.append(*entry)); @@ -195,7 +195,7 @@ impl State { while let Some((id, mut entry)) = self.entries.pop_front() { // Filter entries where the response receiver was dropped (== entries where the request // was dropped by the client) - if entry.response_tx.is_disconnected() { + if entry.response_tx.is_closed() { metrics::increment_counter!("tgi_request_failure", "err" => "dropped"); continue; } @@ -321,9 +321,9 @@ mod tests { fn default_entry() -> ( Entry, - flume::Receiver<Result<InferStreamResponse, InferError>>, + mpsc::UnboundedReceiver<Result<InferStreamResponse, InferError>>, ) { - let (response_tx, receiver_tx) = flume::unbounded(); + let (response_tx, receiver_tx) = mpsc::unbounded_channel(); let entry = Entry { request: ValidGenerateRequest { diff --git a/router/src/validation.rs b/router/src/validation.rs index d0ea137d..7a84640d 100644 --- a/router/src/validation.rs +++ b/router/src/validation.rs @@ -6,6 +6,7 @@ use text_generation_client::{NextTokenChooserParameters, StoppingCriteriaParamet use thiserror::Error; use tokenizers::tokenizer::Tokenizer; use tokenizers::TruncationDirection; +use tokio::sync::mpsc; use tokio::sync::oneshot; use tracing::{instrument, Span}; @@ -19,7 +20,7 @@ pub struct Validation { max_input_length: usize, max_total_tokens: usize, /// Channel to communicate with the background tokenization task - sender: Option<flume::Sender<TokenizerRequest>>, + sender: Option<mpsc::UnboundedSender<TokenizerRequest>>, } impl Validation { @@ -34,19 +35,25 @@ impl Validation { ) -> Self { // If we have a fast tokenizer let sender = if let Some(tokenizer) = tokenizer { - // Create channel - let (validation_sender, validation_receiver) = flume::unbounded(); + // Create round robin channel + let (validation_sender, validation_round_robin_receiver) = mpsc::unbounded_channel(); + let mut senders = Vec::with_capacity(workers); // Create workers for _ in 0..workers { let tokenizer_clone = tokenizer.clone(); - let receiver_clone = validation_receiver.clone(); + let (tokenizer_sender, tokenizer_receiver) = mpsc::unbounded_channel(); + senders.push(tokenizer_sender); // Spawn worker tokio::task::spawn_blocking(move || { - tokenizer_worker(tokenizer_clone, receiver_clone) + tokenizer_worker(tokenizer_clone, tokenizer_receiver) }); } + + // Create tokenization round robin task + tokio::spawn(round_robin_task(validation_round_robin_receiver, senders)); + Some(validation_sender) } else { None @@ -118,12 +125,10 @@ impl Validation { // We make sure that truncate + max_new_tokens <= self.max_total_tokens let max_new_tokens: u32 = if let Some(max_new_tokens) = max_new_tokens { max_new_tokens + } else if let Some(truncate) = truncate { + self.max_total_tokens.saturating_sub(truncate) as u32 } else { - if let 
Some(truncate) = truncate { - self.max_total_tokens.saturating_sub(truncate) as u32 - } else { - return Err(ValidationError::UnsetMaxNewTokens) - } + return Err(ValidationError::UnsetMaxNewTokens); }; let input_length = truncate.unwrap_or(self.max_input_length); @@ -309,10 +314,25 @@ impl Validation { } } +/// Round robin tokenization task +async fn round_robin_task( + mut receiver: mpsc::UnboundedReceiver<TokenizerRequest>, + senders: Vec<mpsc::UnboundedSender<TokenizerRequest>>, +) { + loop { + for sender in &senders { + match receiver.recv().await { + None => return, + Some(request) => sender.send(request).unwrap(), + }; + } + } +} + /// Start tokenization workers -fn tokenizer_worker(tokenizer: Tokenizer, receiver: flume::Receiver<TokenizerRequest>) { +fn tokenizer_worker(tokenizer: Tokenizer, mut receiver: mpsc::UnboundedReceiver<TokenizerRequest>) { // Loop over requests - while let Ok(((inputs, truncate), response_tx, parent_span)) = receiver.recv() { + while let Some(((inputs, truncate), response_tx, parent_span)) = receiver.blocking_recv() { parent_span.in_scope(|| { response_tx .send(prepare_input(inputs, truncate, &tokenizer)) diff --git a/server/Makefile-flash-att-v2 b/server/Makefile-flash-att-v2 index cdea8431..583437b2 100644 --- a/server/Makefile-flash-att-v2 +++ b/server/Makefile-flash-att-v2 @@ -1,4 +1,4 @@ -flash_att_v2_commit := 601b4dc48dbe9d87c468daa2b4c0c8388b83753c +flash_att_v2_commit := 02ac572f3ffc4f402e4183aaa6824b45859d3ed3 flash-attention-v2: # Clone flash attention diff --git a/server/Makefile-vllm b/server/Makefile-vllm index 2e965da0..c601e452 100644 --- a/server/Makefile-vllm +++ b/server/Makefile-vllm @@ -1,8 +1,8 @@ -vllm_commit := 25dbff97d5a8f2ba331847237b458b2692e9ae78 +vllm_commit := f8a1e39fae05ca610be8d5a78be9d40f5274e5fc vllm: # Clone vllm - git clone https://github.com/OlivierDehaene/vllm.git + git clone https://github.com/vllm-project/vllm.git build-vllm: vllm cd vllm && git fetch && git checkout $(vllm_commit) diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py index fccfb0f8..8056a8ec 100644 --- a/server/text_generation_server/models/causal_lm.py +++ b/server/text_generation_server/models/causal_lm.py @@ -511,7 +511,7 @@ class CausalLM(Model): load_in_8bit=quantize == "bitsandbytes", trust_remote_code=trust_remote_code, ) - if torch.cuda.is_available() and torch.cuda.device_count() == 1: + if torch.cuda.is_available() and torch.cuda.device_count() == 1 and quantize != "bitsandbytes": model = model.cuda() if tokenizer.pad_token_id is None: diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index 7c743a88..69608e1c 100644 --- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -29,11 +29,7 @@ from typing import Optional, List, Tuple # Flash attention imports import dropout_layer_norm -# vllm imports -import vllm_cache_ops -import vllm_attention_ops - -from text_generation_server.utils.flash_attn import attention +from text_generation_server.utils import paged_attention, flash_attn from text_generation_server.utils.layers import ( TensorParallelRowLinear, TensorParallelColumnLinear, @@ -269,7 +265,7 @@ class FlashLlamaAttention(torch.nn.Module): self.rotary_emb(query, cos, sin) self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin) - vllm_cache_ops.reshape_and_cache( + paged_attention.reshape_and_cache( kv[:, 0], kv[:, 1], kv_cache[0], 
kv_cache[1], slots ) @@ -279,7 +275,7 @@ class FlashLlamaAttention(torch.nn.Module): # Prefill if cu_seqlen_prefill is not None: # flash attention - attention( + flash_attn.attention( query, torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=1), @@ -290,9 +286,7 @@ class FlashLlamaAttention(torch.nn.Module): ) # Decode else: - # kv_cache[1] => [num_blocks, num_heads, head_size, block_size] - block_size = kv_cache[1].shape[3] - vllm_attention_ops.single_query_cached_kv_attention( + paged_attention.attention( attn_output, query, kv_cache[0], @@ -301,7 +295,6 @@ class FlashLlamaAttention(torch.nn.Module): self.softmax_scale, block_tables, input_lengths, - block_size, max_s, ) diff --git a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py index 77b7f230..2d731406 100644 --- a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py @@ -29,10 +29,7 @@ from typing import Optional, List, Tuple # Flash attention imports import dropout_layer_norm -# vllm imports -import vllm_cache_ops -import vllm_attention_ops - +from text_generation_server.utils import paged_attention, flash_attn from text_generation_server.utils.flash_attn import attention, HAS_FLASH_ATTN_V2 from text_generation_server.utils.layers import ( TensorParallelRowLinear, @@ -272,7 +269,7 @@ class MistralAttention(torch.nn.Module): else: kv_to_cache = kv - vllm_cache_ops.reshape_and_cache( + paged_attention.reshape_and_cache( kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots ) @@ -282,7 +279,7 @@ class MistralAttention(torch.nn.Module): # Prefill if cu_seqlen_prefill is not None: # flash attention - attention( + flash_attn.attention( query, torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=1), @@ -294,9 +291,7 @@ class MistralAttention(torch.nn.Module): ) # Decode else: - # kv_cache[1] => [num_blocks, num_heads, head_size, block_size] - block_size = kv_cache[1].shape[3] - vllm_attention_ops.single_query_cached_kv_attention( + paged_attention.attention( attn_output, query, kv_cache[0], @@ -305,7 +300,6 @@ class MistralAttention(torch.nn.Module): self.softmax_scale, block_tables, input_lengths, - block_size, max_s, ) diff --git a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py index 9dc374df..af4ba96b 100644 --- a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py @@ -27,10 +27,7 @@ from transformers.modeling_utils import PreTrainedModel from transformers.models.gpt_neox import GPTNeoXConfig from typing import Optional, List, Tuple -# vllm imports -import vllm_cache_ops -import vllm_attention_ops - +from text_generation_server.utils import paged_attention, flash_attn from text_generation_server.utils.flash_attn import attention from text_generation_server.utils.layers import ( TensorParallelRowLinear, @@ -141,7 +138,7 @@ class FlashNeoxAttention(torch.nn.Module): self.rotary_emb(qkv[:, 0], cos, sin) self.rotary_emb(qkv[:, 1], cos, sin) - vllm_cache_ops.reshape_and_cache( + paged_attention.reshape_and_cache( qkv[:, 1], qkv[:, 2], kv_cache[0], kv_cache[1], slots ) @@ -151,7 +148,7 @@ class FlashNeoxAttention(torch.nn.Module): # Prefill if cu_seqlen_prefill is not None: # 
flash attention - attention( + flash_attn.attention( qkv[:, 0], qkv[:, 1], qkv[:, 2], @@ -162,9 +159,7 @@ class FlashNeoxAttention(torch.nn.Module): ) # Decode else: - # kv_cache[1] => [num_blocks, num_heads, head_size, block_size] - block_size = kv_cache[1].shape[3] - vllm_attention_ops.single_query_cached_kv_attention( + paged_attention.attention( attn_output, qkv[:, 0], kv_cache[0], @@ -173,7 +168,6 @@ class FlashNeoxAttention(torch.nn.Module): self.softmax_scale, block_tables, input_lengths, - block_size, max_s, ) diff --git a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py index 8419fa4f..00f953a6 100644 --- a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py @@ -6,10 +6,7 @@ from transformers.modeling_utils import PreTrainedModel from transformers.configuration_utils import PretrainedConfig from typing import Optional, List, Tuple -# vllm imports -import vllm_cache_ops -import vllm_attention_ops - +from text_generation_server.utils import paged_attention, flash_attn from text_generation_server.utils.flash_attn import attention from text_generation_server.utils.layers import ( TensorParallelRowLinear, @@ -191,7 +188,7 @@ class FlashRWAttention(torch.nn.Module): self.rotary_emb(query, cos, sin) self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin) - vllm_cache_ops.reshape_and_cache( + paged_attention.reshape_and_cache( kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots ) @@ -201,7 +198,7 @@ class FlashRWAttention(torch.nn.Module): # Prefill if cu_seqlen_prefill is not None: # flash attention - attention( + flash_attn.attention( query, torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=1), @@ -212,9 +209,7 @@ class FlashRWAttention(torch.nn.Module): ) # Decode else: - # kv_cache[1] => [num_blocks, num_heads_kv, head_size, block_size] - block_size = kv_cache[1].shape[3] - vllm_attention_ops.single_query_cached_kv_attention( + paged_attention.attention( attn_output, query, kv_cache[0], @@ -223,7 +218,6 @@ class FlashRWAttention(torch.nn.Module): self.softmax_scale, block_tables, input_lengths, - block_size, max_s, ) @@ -310,7 +304,7 @@ class FlashRWLargeAttention(torch.nn.Module): self.rotary_emb(query, cos, sin) self.rotary_emb(torch.select(kv, dim=2, index=0), cos, sin) - vllm_cache_ops.reshape_and_cache( + paged_attention.reshape_and_cache( kv[:, :, 0].contiguous(), kv[:, :, 1].contiguous(), kv_cache[0], @@ -324,7 +318,7 @@ class FlashRWLargeAttention(torch.nn.Module): # Prefill if cu_seqlen_prefill is not None: # flash attention - attention( + flash_attn.attention( query, torch.select(kv, dim=2, index=0), torch.select(kv, dim=2, index=1), @@ -335,9 +329,7 @@ class FlashRWLargeAttention(torch.nn.Module): ) # Decode else: - # kv_cache[1] => [num_blocks, num_groups, head_size, block_size] - block_size = kv_cache[1].shape[3] - vllm_attention_ops.single_query_cached_kv_attention( + paged_attention.attention( attn_output, query, kv_cache[0], @@ -346,7 +338,6 @@ class FlashRWLargeAttention(torch.nn.Module): self.softmax_scale, block_tables, input_lengths, - block_size, max_s, ) diff --git a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py index 2dd0a5ee..c3c7617a 100644 --- 
a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py @@ -5,10 +5,7 @@ from torch import nn from transformers.activations import ACT2FN from typing import Optional, List, Tuple -# vllm imports -import vllm_cache_ops -import vllm_attention_ops - +from text_generation_server.utils import paged_attention, flash_attn from text_generation_server.utils.flash_attn import attention from text_generation_server.utils.layers import ( TensorParallelRowLinear, @@ -18,7 +15,6 @@ from text_generation_server.utils.layers import ( FastLayerNorm, get_linear, ) -from safetensors import SafetensorError def load_multi_mqa( @@ -258,7 +254,7 @@ class FlashMQAttention(torch.nn.Module): query = query.view(-1, self.num_heads, self.head_size) key_value = key_value.view(-1, 2, 1, self.head_size) - vllm_cache_ops.reshape_and_cache( + paged_attention.reshape_and_cache( key_value[:, 0], key_value[:, 1], kv_cache[0], kv_cache[1], slots ) @@ -268,7 +264,7 @@ class FlashMQAttention(torch.nn.Module): # Prefill if cu_seqlen_prefill is not None: # flash attention - attention( + flash_attn.attention( query, torch.select(key_value, dim=1, index=0), torch.select(key_value, dim=1, index=1), @@ -279,9 +275,7 @@ class FlashMQAttention(torch.nn.Module): ) # Decode else: - # kv_cache[1] => [num_blocks, 1, head_size, block_size] - block_size = kv_cache[1].shape[3] - vllm_attention_ops.single_query_cached_kv_attention( + paged_attention.attention( attn_output, query, kv_cache[0], @@ -290,7 +284,6 @@ class FlashMQAttention(torch.nn.Module): self.softmax_scale, block_tables, input_lengths, - block_size, max_s, ) diff --git a/server/text_generation_server/models/custom_modeling/neox_modeling.py b/server/text_generation_server/models/custom_modeling/neox_modeling.py index 24ba6796..dbcefbae 100644 --- a/server/text_generation_server/models/custom_modeling/neox_modeling.py +++ b/server/text_generation_server/models/custom_modeling/neox_modeling.py @@ -283,10 +283,10 @@ class GPTNeoXAttention(nn.Module): batch_size, num_attention_heads, query_length, attn_head_size = query.size() key_length = key.size(-2) - query = query.view( + query = query.reshape( batch_size * num_attention_heads, query_length, attn_head_size ) - key = key.view(batch_size * num_attention_heads, key_length, attn_head_size) + key = key.reshape(batch_size * num_attention_heads, key_length, attn_head_size) attn_scores = torch.zeros( 1, dtype=query.dtype, diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 1fe40c0c..f1a4854f 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -670,7 +670,7 @@ class FlashCausalLM(Model): self.device, ) _, batch = self.generate_token(batch) - except Exception as e: + except torch.cuda.OutOfMemoryError as e: raise RuntimeError( f"Not enough memory to handle {len(batch.input_ids)} prefill tokens. 
" f"You need to decrease `--max-batch-prefill-tokens`" diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py index f38f130e..7bb95dd2 100644 --- a/server/text_generation_server/utils/layers.py +++ b/server/text_generation_server/utils/layers.py @@ -155,10 +155,7 @@ class EETQLinear(nn.Module): device = weight.device weight = torch.t(weight).contiguous().cpu() weight, scale = quant_weights(weight, torch.int8, False) - if bias: - bias = weights.get_tensor(f"{prefix}.bias") - else: - bias = None + self.weight = weight.cuda(device) self.scale = scale.cuda(device) self.bias = bias.cuda(device) if bias is not None else None diff --git a/server/text_generation_server/utils/paged_attention.py b/server/text_generation_server/utils/paged_attention.py new file mode 100644 index 00000000..57a59599 --- /dev/null +++ b/server/text_generation_server/utils/paged_attention.py @@ -0,0 +1,100 @@ +import torch + +# vllm imports +from vllm import cache_ops +from vllm import attention_ops + +_PARTITION_SIZE = 512 + + +def reshape_and_cache(key: torch.Tensor, value: torch.Tensor, key_cache: torch.Tensor, value_cache: torch.Tensor, + slots: torch.Tensor): + cache_ops.reshape_and_cache( + key, value, key_cache, value_cache, slots + ) + + +def attention( + out: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + kv_head_mapping: torch.Tensor, + softmax_scale: float, + block_tables: torch.Tensor, + input_lengths: torch.Tensor, + max_s: int, +): + # Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py + # Copyright 2023 The vLLM team. All rights + # reserved. + # + # Licensed under the Apache License, Version 2.0 (the "License"); + # you may not use this file except in compliance with the License. + # You may obtain a copy of the License at + # + # http://www.apache.org/licenses/LICENSE-2.0 + # + # Unless required by applicable law or agreed to in writing, software + # distributed under the License is distributed on an "AS IS" BASIS, + # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + # See the License for the specific language governing permissions and + # limitations under the License. + # + + # value_cache => [num_blocks, num_heads, head_size, block_size] + block_size = value_cache.shape[3] + num_seqs, num_heads, head_size = query.shape + max_num_partitions = ( + (max_s + _PARTITION_SIZE - 1) // + _PARTITION_SIZE) + # NOTE(woosuk): We use a simple heuristic to decide whether to use + # PagedAttention V1 or V2. If the number of partitions is 1, we use + # V1 to avoid the overhead of reduction. Also, if the number of + # sequences or heads is large, we use V1 since there is enough work + # to parallelize. + use_v1 = max_num_partitions == 1 or num_seqs * num_heads > 512 + if use_v1: + attention_ops.paged_attention_v1( + out, + query, + key_cache, + value_cache, + kv_head_mapping, + softmax_scale, + block_tables, + input_lengths, + block_size, + max_s, + None, + ) + else: + # Run PagedAttention V2. 
+ assert _PARTITION_SIZE % block_size == 0 + tmp_output = torch.empty( + size=(num_seqs, num_heads, max_num_partitions, head_size), + dtype=out.dtype, + device=out.device, + ) + exp_sums = torch.empty( + size=(num_seqs, num_heads, max_num_partitions), + dtype=torch.float32, + device=out.device, + ) + max_logits = torch.empty_like(exp_sums) + attention_ops.paged_attention_v2( + out, + exp_sums, + max_logits, + tmp_output, + query, + key_cache, + value_cache, + kv_head_mapping, + softmax_scale, + block_tables, + input_lengths, + block_size, + max_s, + None, + )
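The new `paged_attention` module concentrates the vLLM kernel choice in one place: callers just invoke `paged_attention.attention(...)`, the module reads `block_size` from `value_cache.shape[3]` itself (which is why the per-model `block_size = kv_cache[1].shape[3]` lookups could be deleted above), and it decides between `paged_attention_v1` and `paged_attention_v2` from the partition count. The dispatch itself is plain arithmetic over `_PARTITION_SIZE`-token partitions. A standalone rendering of just that heuristic, written in Rust for illustration only (the shipped implementation is the Python above; the function name here is made up):

```rust
/// Mirrors _PARTITION_SIZE in paged_attention.py.
const PARTITION_SIZE: u32 = 512;

/// True when the V1 kernel should be used: a single partition needs no
/// cross-partition reduction, and a large `num_seqs * num_heads` already
/// offers enough parallel work without splitting sequences.
fn use_paged_attention_v1(max_s: u32, num_seqs: u32, num_heads: u32) -> bool {
    // Ceiling division: partitions needed to cover the longest sequence.
    let max_num_partitions = (max_s + PARTITION_SIZE - 1) / PARTITION_SIZE;
    max_num_partitions == 1 || num_seqs * num_heads > 512
}

fn main() {
    // Short sequences fit in one partition, so V1 wins regardless of batch.
    assert!(use_paged_attention_v1(256, 1, 32));
    // A long sequence in a small batch: V2 splits it into 8 partitions and
    // reduces across them via the tmp_output/exp_sums/max_logits buffers.
    assert!(!use_paged_attention_v1(4096, 1, 32));
    println!("dispatch heuristic behaves as expected");
}
```

V2 pays for its extra `tmp_output`, `exp_sums`, and `max_logits` allocations with better GPU occupancy on long sequences, which is why the heuristic only selects it when partitioning is actually needed and the batch is too narrow to saturate the device on its own.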