Merge branch 'huggingface:main' into main

Florian Zimmermeister, 2023-10-25 12:18:33 +02:00 (committed by GitHub)
commit 6a0a378c0c
22 changed files with 216 additions and 171 deletions


@@ -17,5 +17,4 @@ jobs:
 package: text-generation-inference
 additional_args: --not_python_module
 secrets:
-token: ${{ secrets.HUGGINGFACE_PUSH }}
 hf_token: ${{ secrets.HF_DOC_BUILD_PUSH }}

Cargo.lock (generated)

@@ -743,18 +743,6 @@ version = "1.0.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "28a80e3145d8ad11ba0995949bbcf48b9df2be62772b3d351ef017dff6ecb853"
-[[package]]
-name = "flume"
-version = "0.11.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "55ac459de2512911e4b674ce33cf20befaba382d05b62b008afc1c8b57cbf181"
-dependencies = [
- "futures-core",
- "futures-sink",
- "nanorand",
- "spin 0.9.8",
-]
 [[package]]
 name = "fnv"
 version = "1.0.7"
@@ -900,10 +888,8 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "be4136b2a15dd319360be1c07d9933517ccf0be8f16bf62a3bee4f0d618df427"
 dependencies = [
  "cfg-if",
- "js-sys",
  "libc",
  "wasi",
- "wasm-bindgen",
 ]
 [[package]]
@@ -1508,15 +1494,6 @@ dependencies = [
  "tracing",
 ]
-[[package]]
-name = "nanorand"
-version = "0.7.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "6a51313c5820b0b02bd422f4b44776fbf47961755c74ce64afc73bfad10226c3"
-dependencies = [
- "getrandom",
-]
 [[package]]
 name = "native-tls"
 version = "0.2.11"
@@ -2313,7 +2290,7 @@ dependencies = [
  "cc",
  "libc",
  "once_cell",
- "spin 0.5.2",
+ "spin",
  "untrusted",
  "web-sys",
  "winapi",
@@ -2678,15 +2655,6 @@ version = "0.5.2"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "6e63cff320ae2c57904679ba7cb63280a3dc4613885beafb148ee7bf9aa9042d"
-[[package]]
-name = "spin"
-version = "0.9.8"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67"
-dependencies = [
- "lock_api",
-]
 [[package]]
 name = "spm_precompiled"
 version = "0.1.4"
@@ -2808,7 +2776,7 @@ dependencies = [
 [[package]]
 name = "text-generation-benchmark"
-version = "1.1.0"
+version = "1.1.1"
 dependencies = [
  "average",
  "clap",
@@ -2829,7 +2797,7 @@ dependencies = [
 [[package]]
 name = "text-generation-client"
-version = "1.1.0"
+version = "1.1.1"
 dependencies = [
  "futures",
  "grpc-metadata",
@@ -2845,7 +2813,7 @@ dependencies = [
 [[package]]
 name = "text-generation-launcher"
-version = "1.1.0"
+version = "1.1.1"
 dependencies = [
  "clap",
  "ctrlc",
@@ -2861,13 +2829,12 @@ dependencies = [
 [[package]]
 name = "text-generation-router"
-version = "1.1.0"
+version = "1.1.1"
 dependencies = [
  "async-stream",
  "axum",
  "axum-tracing-opentelemetry",
  "clap",
- "flume",
  "futures",
  "hf-hub 0.3.1",
  "init-tracing-opentelemetry",
@@ -2885,6 +2852,7 @@ dependencies = [
  "thiserror",
  "tokenizers",
  "tokio",
+ "tokio-stream",
  "tower-http",
  "tracing",
  "tracing-opentelemetry",


@@ -4,7 +4,7 @@ Text Generation Inference improves the model in several aspects.
 ## Quantization
-TGI supports [bits-and-bytes](https://github.com/TimDettmers/bitsandbytes#bitsandbytes), [GPT-Q](https://arxiv.org/abs/2210.17323) and [AWQ](https://arxiv.org/abs/2306.00978) quantization. To speed up inference with quantization, simply set `quantize` flag to `bitsandbytes`, `gptq` or `awq` depending on the quantization technique you wish to use. When using GPT-Q quantization, you need to point to one of the models [here](https://huggingface.co/models?search=gptq) when using AWQ quantization, you need to point to one of the models [here](https://huggingface.co/models?search=awq). To get more information about quantization, please refer to [quantization guide](./conceptual/quantization.md)
+TGI supports [bits-and-bytes](https://github.com/TimDettmers/bitsandbytes#bitsandbytes), [GPT-Q](https://arxiv.org/abs/2210.17323) and [AWQ](https://arxiv.org/abs/2306.00978) quantization. To speed up inference with quantization, simply set `quantize` flag to `bitsandbytes`, `gptq` or `awq` depending on the quantization technique you wish to use. When using GPT-Q quantization, you need to point to one of the models [here](https://huggingface.co/models?search=gptq) when using AWQ quantization, you need to point to one of the models [here](https://huggingface.co/models?search=awq). To get more information about quantization, please refer to [quantization guide](./../conceptual/quantization.md)
 ## RoPE Scaling


@@ -20,7 +20,6 @@ axum = { version = "0.6.20", features = ["json"] }
 axum-tracing-opentelemetry = "0.14.1"
 text-generation-client = { path = "client" }
 clap = { version = "4.4.5", features = ["derive", "env"] }
-flume = "0.11.0"
 futures = "0.3.28"
 metrics = "0.21.1"
 metrics-exporter-prometheus = { version = "0.12.1", features = [] }
@@ -34,6 +33,7 @@ serde_json = "1.0.107"
 thiserror = "1.0.48"
 tokenizers = { version = "0.14.0", features = ["http"] }
 tokio = { version = "1.32.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] }
+tokio-stream = "0.1.14"
 tower-http = { version = "0.4.4", features = ["cors"] }
 tracing = "0.1.37"
 tracing-opentelemetry = "0.21.0"


@@ -103,17 +103,18 @@ impl Client {
 &mut self,
 max_input_length: u32,
 max_prefill_tokens: u32,
+max_total_tokens: u32,
 ) -> Result<Option<u32>> {
 let mut n_tokens = 0;
 let mut requests = Vec::new();
 // Create requests
 while n_tokens < max_prefill_tokens {
+let truncate = min(max_input_length, max_prefill_tokens - n_tokens);
 requests.push(Request {
 id: 0,
 // We truncate the input on the server side to be sure that it has the correct size
 inputs: "_test ".to_string().repeat(max_input_length as usize),
-truncate: min(max_input_length, max_prefill_tokens - n_tokens),
+truncate,
 // Set sampling parameters to also take these ops into account in the max memory
 parameters: Some(NextTokenChooserParameters {
 temperature: 0.9,
@@ -126,9 +127,9 @@ impl Client {
 watermark: true,
 }),
 stopping_parameters: Some(StoppingCriteriaParameters {
-max_new_tokens: 2,
+max_new_tokens: max_total_tokens - truncate,
 stop_sequences: vec![],
-ignore_eos_token: false,
+ignore_eos_token: true,
 }),
 prefill_logprobs: true,
 top_n_tokens: 20,
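
The warmup hunk above does two things: each synthetic request now reuses a single `truncate` binding, and its stopping criteria ask for the full remaining decode budget (`max_new_tokens: max_total_tokens - truncate` with `ignore_eos_token: true`) instead of a fixed 2 tokens, so warmup exercises worst-case memory use. A minimal standalone sketch of that sizing arithmetic, with illustrative numbers (not the TGI code itself; it assumes `max_input_length < max_total_tokens`, which the router checks elsewhere):

```rust
use std::cmp::min;

/// Returns (truncate, max_new_tokens) pairs for the synthetic warmup requests.
fn warmup_request_sizes(
    max_input_length: u32,
    max_prefill_tokens: u32,
    max_total_tokens: u32,
) -> Vec<(u32, u32)> {
    let mut sizes = Vec::new();
    let mut n_tokens = 0;
    while n_tokens < max_prefill_tokens {
        // Input length for this request, capped by the remaining prefill budget
        let truncate = min(max_input_length, max_prefill_tokens - n_tokens);
        // Decode budget: everything left of the per-request total-token budget
        // (assumes max_input_length < max_total_tokens, so this never underflows)
        let max_new_tokens = max_total_tokens - truncate;
        sizes.push((truncate, max_new_tokens));
        n_tokens += truncate;
    }
    sizes
}

fn main() {
    // e.g. 1024 / 4096 / 2048 => four requests of (1024 input, 1024 decode) tokens
    println!("{:?}", warmup_request_sizes(1024, 4096, 2048));
}
```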


@@ -95,11 +95,14 @@ impl ShardedClient {
 &mut self,
 max_input_length: u32,
 max_prefill_tokens: u32,
+max_total_tokens: u32,
 ) -> Result<Option<u32>> {
 let futures: Vec<_> = self
 .clients
 .iter_mut()
-.map(|client| Box::pin(client.warmup(max_input_length, max_prefill_tokens)))
+.map(|client| {
+Box::pin(client.warmup(max_input_length, max_prefill_tokens, max_total_tokens))
+})
 .collect();
 // Take the minimum value
 let results = join_all(futures)


@@ -2,22 +2,21 @@
 use crate::validation::{Validation, ValidationError};
 use crate::{Entry, Queue, Token};
 use crate::{GenerateRequest, PrefillToken};
-use flume::r#async::RecvStream;
-use flume::SendTimeoutError;
 use futures::future::try_join_all;
-use futures::stream::StreamExt;
 use nohash_hasher::IntMap;
 use std::sync::{
 atomic::{AtomicBool, Ordering},
 Arc,
 };
-use std::time::Duration;
 use text_generation_client::{
 Batch, CachedBatch, ClientError, GeneratedText, Generation, PrefillTokens, ShardedClient,
 };
 use thiserror::Error;
-use tokio::sync::{Notify, OwnedSemaphorePermit, Semaphore, TryAcquireError};
+use tokio::sync::mpsc::error::SendError;
+use tokio::sync::{mpsc, Notify, OwnedSemaphorePermit, Semaphore, TryAcquireError};
 use tokio::time::Instant;
+use tokio_stream::wrappers::UnboundedReceiverStream;
+use tokio_stream::StreamExt;
 use tracing::{info_span, instrument, Instrument, Span};
 /// Inference struct
@@ -90,7 +89,7 @@ impl Infer {
 ) -> Result<
 (
 OwnedSemaphorePermit,
-RecvStream<Result<InferStreamResponse, InferError>>,
+UnboundedReceiverStream<Result<InferStreamResponse, InferError>>,
 ),
 InferError,
 > {
@@ -113,7 +112,7 @@ impl Infer {
 })?;
 // MPSC channel to communicate with the background batching task
-let (response_tx, response_rx) = flume::unbounded();
+let (response_tx, response_rx) = mpsc::unbounded_channel();
 // Append the request to the queue
 self.queue.append(Entry {
@@ -130,7 +129,7 @@
 self.shared.batching_task.notify_one();
 // Return stream
-Ok((permit, response_rx.into_stream()))
+Ok((permit, UnboundedReceiverStream::new(response_rx)))
 }
 /// Add a new request to the queue and return a InferResponse
@@ -493,10 +492,7 @@ fn filter_send_generations(generations: Vec<Generation>, entries: &mut IntMap<u6
 // If the receive an error from the Flume channel, it means that the client dropped the
 // request and we need to stop generating hence why we unwrap_or(true)
 let stopped = send_responses(generation, entry).map_err(|err| {
-if let SendTimeoutError::Timeout(_) = *err {
-tracing::error!("Entry response channel timed out.")
-}
+tracing::error!("Entry response channel error.");
 metrics::increment_counter!("tgi_request_failure", "err" => "dropped");
 err
 }).unwrap_or(true);
@@ -510,9 +506,10 @@ fn filter_send_generations(generations: Vec<Generation>, entries: &mut IntMap<u6
 fn send_responses(
 generation: Generation,
 entry: &Entry,
-) -> Result<bool, Box<SendTimeoutError<Result<InferStreamResponse, InferError>>>> {
+) -> Result<bool, Box<SendError<Result<InferStreamResponse, InferError>>>> {
 // Return directly if the channel is disconnected
-if entry.response_tx.is_disconnected() {
+if entry.response_tx.is_closed() {
+metrics::increment_counter!("tgi_request_failure", "err" => "dropped");
 return Ok(true);
 }
@@ -520,10 +517,9 @@ fn send_responses(
 if let Some(prefill_tokens) = generation.prefill_tokens {
 // Send message
-entry.response_tx.send_timeout(
-Ok(InferStreamResponse::Prefill(prefill_tokens)),
-Duration::from_millis(10),
-)?;
+entry
+.response_tx
+.send(Ok(InferStreamResponse::Prefill(prefill_tokens)))?;
 }
 // Create last Token
@@ -558,22 +554,18 @@
 // Generation has ended
 stopped = true;
 // Send message
-entry.response_tx.send_timeout(
-Ok(InferStreamResponse::End {
+entry.response_tx.send(Ok(InferStreamResponse::End {
 token,
 top_tokens,
 generated_text,
 queued: entry.queue_time,
 start: entry.batch_time.unwrap(),
-}),
-Duration::from_millis(10),
-)?;
+}))?;
 } else {
 // Send message
-entry.response_tx.send_timeout(
-Ok(InferStreamResponse::Intermediate { token, top_tokens }),
-Duration::from_millis(10),
-)?;
+entry
+.response_tx
+.send(Ok(InferStreamResponse::Intermediate { token, top_tokens }))?;
 }
 Ok(stopped)
 }
@@ -591,7 +583,7 @@ fn send_errors(error: ClientError, entries: &mut IntMap<u64, Entry>) {
 // unwrap_or is valid here as we don't care if the receiver is gone.
 entry
 .response_tx
-.send_timeout(Err(err), Duration::from_millis(10))
+.send(Err(err))
 .unwrap_or(());
 });
 }
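
The infer.rs hunks above replace flume with tokio's own primitives: the per-request channel becomes `mpsc::unbounded_channel()`, the stream handed back to the HTTP layer is an `UnboundedReceiverStream`, the timeout-based `send_timeout` calls become plain `send`, and `is_closed()` stands in for flume's `is_disconnected()`. A minimal, self-contained sketch of that channel-plus-stream pattern (toy payload and names, not the TGI structs; assumed crates: `tokio = { version = "1", features = ["full"] }`, `tokio-stream = "0.1"`):

```rust
use tokio::sync::mpsc;
use tokio_stream::wrappers::UnboundedReceiverStream;
use tokio_stream::StreamExt;

#[tokio::main]
async fn main() {
    // Unbounded channel: the sender side lives in a background task,
    // the receiver is wrapped in a Stream and handed to the caller.
    let (response_tx, response_rx) = mpsc::unbounded_channel::<String>();
    let mut stream = UnboundedReceiverStream::new(response_rx);

    tokio::spawn(async move {
        for i in 0..3 {
            // send() only errors if the receiver was dropped, i.e. the
            // client went away; is_closed() can also be checked up front.
            if response_tx.send(format!("token {i}")).is_err() {
                break;
            }
        }
        // Dropping response_tx ends the stream below.
    });

    while let Some(item) = stream.next().await {
        println!("{item}");
    }
}
```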


@@ -212,7 +212,7 @@ fn main() -> Result<(), RouterError> {
 // Warmup model
 tracing::info!("Warming up model");
 let max_supported_batch_total_tokens = match sharded_client
-.warmup(max_input_length as u32, max_batch_prefill_tokens)
+.warmup(max_input_length as u32, max_batch_prefill_tokens, max_total_tokens as u32)
 .await
 .map_err(RouterError::Warmup)?
 {


@@ -5,7 +5,7 @@ use nohash_hasher::{BuildNoHashHasher, IntMap};
 use std::cmp::min;
 use std::collections::VecDeque;
 use text_generation_client::{Batch, Request};
-use tokio::sync::oneshot;
+use tokio::sync::{mpsc, oneshot};
 use tokio::time::Instant;
 use tracing::{info_span, instrument, Span};
@@ -15,7 +15,7 @@ pub(crate) struct Entry {
 /// Request
 pub request: ValidGenerateRequest,
 /// Response sender to communicate between the Infer struct and the batching_task
-pub response_tx: flume::Sender<Result<InferStreamResponse, InferError>>,
+pub response_tx: mpsc::UnboundedSender<Result<InferStreamResponse, InferError>>,
 /// Span that will live as long as entry
 pub span: Span,
 /// Temporary span used as a guard when logging inference, wait times...
@@ -30,13 +30,13 @@
 #[derive(Debug, Clone)]
 pub(crate) struct Queue {
 /// Channel to communicate with the background queue task
-queue_sender: flume::Sender<QueueCommand>,
+queue_sender: mpsc::UnboundedSender<QueueCommand>,
 }
 impl Queue {
 pub(crate) fn new(requires_padding: bool, block_size: u32, window_size: Option<u32>) -> Self {
 // Create channel
-let (queue_sender, queue_receiver) = flume::unbounded();
+let (queue_sender, queue_receiver) = mpsc::unbounded_channel();
 // Launch background queue task
 tokio::spawn(queue_task(
@@ -91,11 +91,11 @@ async fn queue_task(
 requires_padding: bool,
 block_size: u32,
 window_size: Option<u32>,
-receiver: flume::Receiver<QueueCommand>,
+mut receiver: mpsc::UnboundedReceiver<QueueCommand>,
 ) {
 let mut state = State::new(requires_padding, block_size, window_size);
-while let Ok(cmd) = receiver.recv_async().await {
+while let Some(cmd) = receiver.recv().await {
 match cmd {
 QueueCommand::Append(entry, span) => {
 span.in_scope(|| state.append(*entry));
@@ -195,7 +195,7 @@ impl State {
 while let Some((id, mut entry)) = self.entries.pop_front() {
 // Filter entries where the response receiver was dropped (== entries where the request
 // was dropped by the client)
-if entry.response_tx.is_disconnected() {
+if entry.response_tx.is_closed() {
 metrics::increment_counter!("tgi_request_failure", "err" => "dropped");
 continue;
 }
@@ -321,9 +321,9 @@ mod tests {
 fn default_entry() -> (
 Entry,
-flume::Receiver<Result<InferStreamResponse, InferError>>,
+mpsc::UnboundedReceiver<Result<InferStreamResponse, InferError>>,
 ) {
-let (response_tx, receiver_tx) = flume::unbounded();
+let (response_tx, receiver_tx) = mpsc::unbounded_channel();
 let entry = Entry {
 request: ValidGenerateRequest {


@@ -6,6 +6,7 @@ use text_generation_client::{NextTokenChooserParameters, StoppingCriteriaParamet
 use thiserror::Error;
 use tokenizers::tokenizer::Tokenizer;
 use tokenizers::TruncationDirection;
+use tokio::sync::mpsc;
 use tokio::sync::oneshot;
 use tracing::{instrument, Span};
@@ -19,7 +20,7 @@ pub struct Validation {
 max_input_length: usize,
 max_total_tokens: usize,
 /// Channel to communicate with the background tokenization task
-sender: Option<flume::Sender<TokenizerRequest>>,
+sender: Option<mpsc::UnboundedSender<TokenizerRequest>>,
 }
 impl Validation {
@@ -34,19 +35,25 @@ impl Validation {
 ) -> Self {
 // If we have a fast tokenizer
 let sender = if let Some(tokenizer) = tokenizer {
-// Create channel
-let (validation_sender, validation_receiver) = flume::unbounded();
+// Create round robin channel
+let (validation_sender, validation_round_robin_receiver) = mpsc::unbounded_channel();
+let mut senders = Vec::with_capacity(workers);
 // Create workers
 for _ in 0..workers {
 let tokenizer_clone = tokenizer.clone();
-let receiver_clone = validation_receiver.clone();
+let (tokenizer_sender, tokenizer_receiver) = mpsc::unbounded_channel();
+senders.push(tokenizer_sender);
 // Spawn worker
 tokio::task::spawn_blocking(move || {
-tokenizer_worker(tokenizer_clone, receiver_clone)
+tokenizer_worker(tokenizer_clone, tokenizer_receiver)
 });
 }
+// Create tokenization round robin task
+tokio::spawn(round_robin_task(validation_round_robin_receiver, senders));
 Some(validation_sender)
 } else {
 None
@@ -118,12 +125,10 @@
 // We make sure that truncate + max_new_tokens <= self.max_total_tokens
 let max_new_tokens: u32 = if let Some(max_new_tokens) = max_new_tokens {
 max_new_tokens
-} else {
-if let Some(truncate) = truncate {
+} else if let Some(truncate) = truncate {
 self.max_total_tokens.saturating_sub(truncate) as u32
 } else {
-return Err(ValidationError::UnsetMaxNewTokens)
-}
+return Err(ValidationError::UnsetMaxNewTokens);
 };
 let input_length = truncate.unwrap_or(self.max_input_length);
@@ -309,10 +314,25 @@
 }
 }
+/// Round robin tokenization task
+async fn round_robin_task(
+mut receiver: mpsc::UnboundedReceiver<TokenizerRequest>,
+senders: Vec<mpsc::UnboundedSender<TokenizerRequest>>,
+) {
+loop {
+for sender in &senders {
+match receiver.recv().await {
+None => return,
+Some(request) => sender.send(request).unwrap(),
+};
+}
+}
+}
 /// Start tokenization workers
-fn tokenizer_worker(tokenizer: Tokenizer, receiver: flume::Receiver<TokenizerRequest>) {
+fn tokenizer_worker(tokenizer: Tokenizer, mut receiver: mpsc::UnboundedReceiver<TokenizerRequest>) {
 // Loop over requests
-while let Ok(((inputs, truncate), response_tx, parent_span)) = receiver.recv() {
+while let Some(((inputs, truncate), response_tx, parent_span)) = receiver.blocking_recv() {
 parent_span.in_scope(|| {
 response_tx
 .send(prepare_input(inputs, truncate, &tokenizer))
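
The new `round_robin_task` above exists because, unlike flume's receiver, a tokio `mpsc` receiver cannot be cloned across consumers: each blocking tokenizer worker now owns its own channel, and a small async task deals incoming requests out in turn. A stripped-down sketch of that fan-out (toy `String` jobs and made-up names, not the TGI types; same assumed crates as the previous sketch):

```rust
use tokio::sync::mpsc;

/// Deal jobs from one receiver out to the workers, one each in turn.
async fn round_robin(
    mut receiver: mpsc::UnboundedReceiver<String>,
    senders: Vec<mpsc::UnboundedSender<String>>,
) {
    loop {
        for sender in &senders {
            match receiver.recv().await {
                None => return, // all producers dropped, shut down
                Some(job) => {
                    // A send error would mean the worker died; ignore it here.
                    let _ = sender.send(job);
                }
            }
        }
    }
}

#[tokio::main]
async fn main() {
    let (tx, rx) = mpsc::unbounded_channel();
    let mut worker_senders = Vec::new();
    for id in 0..2 {
        let (worker_tx, mut worker_rx) = mpsc::unbounded_channel::<String>();
        worker_senders.push(worker_tx);
        // Each worker owns its receiver and blocks on it, mirroring the
        // tokenizer workers spawned with spawn_blocking in the diff.
        tokio::task::spawn_blocking(move || {
            while let Some(job) = worker_rx.blocking_recv() {
                println!("worker {id} handling {job}");
            }
        });
    }
    tokio::spawn(round_robin(rx, worker_senders));

    for i in 0..4 {
        let _ = tx.send(format!("request {i}"));
    }
    drop(tx);
    // Give the blocking workers a moment to drain before exiting.
    tokio::time::sleep(std::time::Duration::from_millis(100)).await;
}
```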


@@ -1,4 +1,4 @@
-flash_att_v2_commit := 601b4dc48dbe9d87c468daa2b4c0c8388b83753c
+flash_att_v2_commit := 02ac572f3ffc4f402e4183aaa6824b45859d3ed3
 flash-attention-v2:
 # Clone flash attention


@@ -1,8 +1,8 @@
-vllm_commit := 25dbff97d5a8f2ba331847237b458b2692e9ae78
+vllm_commit := f8a1e39fae05ca610be8d5a78be9d40f5274e5fc
 vllm:
 # Clone vllm
-git clone https://github.com/OlivierDehaene/vllm.git
+git clone https://github.com/vllm-project/vllm.git
 build-vllm: vllm
 cd vllm && git fetch && git checkout $(vllm_commit)


@@ -511,7 +511,7 @@ class CausalLM(Model):
 load_in_8bit=quantize == "bitsandbytes",
 trust_remote_code=trust_remote_code,
 )
-if torch.cuda.is_available() and torch.cuda.device_count() == 1:
+if torch.cuda.is_available() and torch.cuda.device_count() == 1 and quantize != "bitsandbytes":
 model = model.cuda()
 if tokenizer.pad_token_id is None:


@@ -29,11 +29,7 @@ from typing import Optional, List, Tuple
 # Flash attention imports
 import dropout_layer_norm
-# vllm imports
-import vllm_cache_ops
-import vllm_attention_ops
-from text_generation_server.utils.flash_attn import attention
+from text_generation_server.utils import paged_attention, flash_attn
 from text_generation_server.utils.layers import (
 TensorParallelRowLinear,
 TensorParallelColumnLinear,
@@ -269,7 +265,7 @@ class FlashLlamaAttention(torch.nn.Module):
 self.rotary_emb(query, cos, sin)
 self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin)
-vllm_cache_ops.reshape_and_cache(
+paged_attention.reshape_and_cache(
 kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots
 )
@@ -279,7 +275,7 @@
 # Prefill
 if cu_seqlen_prefill is not None:
 # flash attention
-attention(
+flash_attn.attention(
 query,
 torch.select(kv, dim=1, index=0),
 torch.select(kv, dim=1, index=1),
@@ -290,9 +286,7 @@ class FlashLlamaAttention(torch.nn.Module):
 )
 # Decode
 else:
-# kv_cache[1] => [num_blocks, num_heads, head_size, block_size]
-block_size = kv_cache[1].shape[3]
-vllm_attention_ops.single_query_cached_kv_attention(
+paged_attention.attention(
 attn_output,
 query,
 kv_cache[0],
@@ -301,7 +295,6 @@
 self.softmax_scale,
 block_tables,
 input_lengths,
-block_size,
 max_s,
 )


@@ -29,10 +29,7 @@ from typing import Optional, List, Tuple
 # Flash attention imports
 import dropout_layer_norm
-# vllm imports
-import vllm_cache_ops
-import vllm_attention_ops
+from text_generation_server.utils import paged_attention, flash_attn
 from text_generation_server.utils.flash_attn import attention, HAS_FLASH_ATTN_V2
 from text_generation_server.utils.layers import (
 TensorParallelRowLinear,
@@ -272,7 +269,7 @@ class MistralAttention(torch.nn.Module):
 else:
 kv_to_cache = kv
-vllm_cache_ops.reshape_and_cache(
+paged_attention.reshape_and_cache(
 kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots
 )
@@ -282,7 +279,7 @@
 # Prefill
 if cu_seqlen_prefill is not None:
 # flash attention
-attention(
+flash_attn.attention(
 query,
 torch.select(kv, dim=1, index=0),
 torch.select(kv, dim=1, index=1),
@@ -294,9 +291,7 @@
 )
 # Decode
 else:
-# kv_cache[1] => [num_blocks, num_heads, head_size, block_size]
-block_size = kv_cache[1].shape[3]
-vllm_attention_ops.single_query_cached_kv_attention(
+paged_attention.attention(
 attn_output,
 query,
 kv_cache[0],
@@ -305,7 +300,6 @@
 self.softmax_scale,
 block_tables,
 input_lengths,
-block_size,
 max_s,
 )


@@ -27,10 +27,7 @@ from transformers.modeling_utils import PreTrainedModel
 from transformers.models.gpt_neox import GPTNeoXConfig
 from typing import Optional, List, Tuple
-# vllm imports
-import vllm_cache_ops
-import vllm_attention_ops
+from text_generation_server.utils import paged_attention, flash_attn
 from text_generation_server.utils.flash_attn import attention
 from text_generation_server.utils.layers import (
 TensorParallelRowLinear,
@@ -141,7 +138,7 @@ class FlashNeoxAttention(torch.nn.Module):
 self.rotary_emb(qkv[:, 0], cos, sin)
 self.rotary_emb(qkv[:, 1], cos, sin)
-vllm_cache_ops.reshape_and_cache(
+paged_attention.reshape_and_cache(
 qkv[:, 1], qkv[:, 2], kv_cache[0], kv_cache[1], slots
 )
@@ -151,7 +148,7 @@
 # Prefill
 if cu_seqlen_prefill is not None:
 # flash attention
-attention(
+flash_attn.attention(
 qkv[:, 0],
 qkv[:, 1],
 qkv[:, 2],
@@ -162,9 +159,7 @@
 )
 # Decode
 else:
-# kv_cache[1] => [num_blocks, num_heads, head_size, block_size]
-block_size = kv_cache[1].shape[3]
-vllm_attention_ops.single_query_cached_kv_attention(
+paged_attention.attention(
 attn_output,
 qkv[:, 0],
 kv_cache[0],
@@ -173,7 +168,6 @@
 self.softmax_scale,
 block_tables,
 input_lengths,
-block_size,
 max_s,
 )


@@ -6,10 +6,7 @@ from transformers.modeling_utils import PreTrainedModel
 from transformers.configuration_utils import PretrainedConfig
 from typing import Optional, List, Tuple
-# vllm imports
-import vllm_cache_ops
-import vllm_attention_ops
+from text_generation_server.utils import paged_attention, flash_attn
 from text_generation_server.utils.flash_attn import attention
 from text_generation_server.utils.layers import (
 TensorParallelRowLinear,
@@ -191,7 +188,7 @@ class FlashRWAttention(torch.nn.Module):
 self.rotary_emb(query, cos, sin)
 self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin)
-vllm_cache_ops.reshape_and_cache(
+paged_attention.reshape_and_cache(
 kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots
 )
@@ -201,7 +198,7 @@
 # Prefill
 if cu_seqlen_prefill is not None:
 # flash attention
-attention(
+flash_attn.attention(
 query,
 torch.select(kv, dim=1, index=0),
 torch.select(kv, dim=1, index=1),
@@ -212,9 +209,7 @@
 )
 # Decode
 else:
-# kv_cache[1] => [num_blocks, num_heads_kv, head_size, block_size]
-block_size = kv_cache[1].shape[3]
-vllm_attention_ops.single_query_cached_kv_attention(
+paged_attention.attention(
 attn_output,
 query,
 kv_cache[0],
@@ -223,7 +218,6 @@
 self.softmax_scale,
 block_tables,
 input_lengths,
-block_size,
 max_s,
 )
@@ -310,7 +304,7 @@ class FlashRWLargeAttention(torch.nn.Module):
 self.rotary_emb(query, cos, sin)
 self.rotary_emb(torch.select(kv, dim=2, index=0), cos, sin)
-vllm_cache_ops.reshape_and_cache(
+paged_attention.reshape_and_cache(
 kv[:, :, 0].contiguous(),
 kv[:, :, 1].contiguous(),
 kv_cache[0],
@@ -324,7 +318,7 @@
 # Prefill
 if cu_seqlen_prefill is not None:
 # flash attention
-attention(
+flash_attn.attention(
 query,
 torch.select(kv, dim=2, index=0),
 torch.select(kv, dim=2, index=1),
@@ -335,9 +329,7 @@
 )
 # Decode
 else:
-# kv_cache[1] => [num_blocks, num_groups, head_size, block_size]
-block_size = kv_cache[1].shape[3]
-vllm_attention_ops.single_query_cached_kv_attention(
+paged_attention.attention(
 attn_output,
 query,
 kv_cache[0],
@@ -346,7 +338,6 @@
 self.softmax_scale,
 block_tables,
 input_lengths,
-block_size,
 max_s,
 )


@@ -5,10 +5,7 @@ from torch import nn
 from transformers.activations import ACT2FN
 from typing import Optional, List, Tuple
-# vllm imports
-import vllm_cache_ops
-import vllm_attention_ops
+from text_generation_server.utils import paged_attention, flash_attn
 from text_generation_server.utils.flash_attn import attention
 from text_generation_server.utils.layers import (
 TensorParallelRowLinear,
@@ -18,7 +15,6 @@ from text_generation_server.utils.layers import (
 FastLayerNorm,
 get_linear,
 )
-from safetensors import SafetensorError
 def load_multi_mqa(
@@ -258,7 +254,7 @@ class FlashMQAttention(torch.nn.Module):
 query = query.view(-1, self.num_heads, self.head_size)
 key_value = key_value.view(-1, 2, 1, self.head_size)
-vllm_cache_ops.reshape_and_cache(
+paged_attention.reshape_and_cache(
 key_value[:, 0], key_value[:, 1], kv_cache[0], kv_cache[1], slots
 )
@@ -268,7 +264,7 @@
 # Prefill
 if cu_seqlen_prefill is not None:
 # flash attention
-attention(
+flash_attn.attention(
 query,
 torch.select(key_value, dim=1, index=0),
 torch.select(key_value, dim=1, index=1),
@@ -279,9 +275,7 @@
 )
 # Decode
 else:
-# kv_cache[1] => [num_blocks, 1, head_size, block_size]
-block_size = kv_cache[1].shape[3]
-vllm_attention_ops.single_query_cached_kv_attention(
+paged_attention.attention(
 attn_output,
 query,
 kv_cache[0],
@@ -290,7 +284,6 @@
 self.softmax_scale,
 block_tables,
 input_lengths,
-block_size,
 max_s,
 )


@@ -283,10 +283,10 @@ class GPTNeoXAttention(nn.Module):
 batch_size, num_attention_heads, query_length, attn_head_size = query.size()
 key_length = key.size(-2)
-query = query.view(
+query = query.reshape(
 batch_size * num_attention_heads, query_length, attn_head_size
 )
-key = key.view(batch_size * num_attention_heads, key_length, attn_head_size)
+key = key.reshape(batch_size * num_attention_heads, key_length, attn_head_size)
 attn_scores = torch.zeros(
 1,
 dtype=query.dtype,


@@ -670,7 +670,7 @@ class FlashCausalLM(Model):
 self.device,
 )
 _, batch = self.generate_token(batch)
-except Exception as e:
+except torch.cuda.OutOfMemoryError as e:
 raise RuntimeError(
 f"Not enough memory to handle {len(batch.input_ids)} prefill tokens. "
 f"You need to decrease `--max-batch-prefill-tokens`"


@@ -155,10 +155,7 @@ class EETQLinear(nn.Module):
 device = weight.device
 weight = torch.t(weight).contiguous().cpu()
 weight, scale = quant_weights(weight, torch.int8, False)
-if bias:
-bias = weights.get_tensor(f"{prefix}.bias")
-else:
-bias = None
 self.weight = weight.cuda(device)
 self.scale = scale.cuda(device)
 self.bias = bias.cuda(device) if bias is not None else None


@@ -0,0 +1,100 @@
+import torch
+
+# vllm imports
+from vllm import cache_ops
+from vllm import attention_ops
+
+_PARTITION_SIZE = 512
+
+
+def reshape_and_cache(key: torch.Tensor, value: torch.Tensor, key_cache: torch.Tensor, value_cache: torch.Tensor,
+                      slots: torch.Tensor):
+    cache_ops.reshape_and_cache(
+        key, value, key_cache, value_cache, slots
+    )
+
+
+def attention(
+    out: torch.Tensor,
+    query: torch.Tensor,
+    key_cache: torch.Tensor,
+    value_cache: torch.Tensor,
+    kv_head_mapping: torch.Tensor,
+    softmax_scale: float,
+    block_tables: torch.Tensor,
+    input_lengths: torch.Tensor,
+    max_s: int,
+):
+    # Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py
+    # Copyright 2023 The vLLM team. All rights
+    # reserved.
+    #
+    # Licensed under the Apache License, Version 2.0 (the "License");
+    # you may not use this file except in compliance with the License.
+    # You may obtain a copy of the License at
+    #
+    #     http://www.apache.org/licenses/LICENSE-2.0
+    #
+    # Unless required by applicable law or agreed to in writing, software
+    # distributed under the License is distributed on an "AS IS" BASIS,
+    # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+    # See the License for the specific language governing permissions and
+    # limitations under the License.
+    #
+
+    # value_cache => [num_blocks, num_heads, head_size, block_size]
+    block_size = value_cache.shape[3]
+    num_seqs, num_heads, head_size = query.shape
+    max_num_partitions = (
+        (max_s + _PARTITION_SIZE - 1) //
+        _PARTITION_SIZE)
+
+    # NOTE(woosuk): We use a simple heuristic to decide whether to use
+    # PagedAttention V1 or V2. If the number of partitions is 1, we use
+    # V1 to avoid the overhead of reduction. Also, if the number of
+    # sequences or heads is large, we use V1 since there is enough work
+    # to parallelize.
+    use_v1 = max_num_partitions == 1 or num_seqs * num_heads > 512
+    if use_v1:
+        attention_ops.paged_attention_v1(
+            out,
+            query,
+            key_cache,
+            value_cache,
+            kv_head_mapping,
+            softmax_scale,
+            block_tables,
+            input_lengths,
+            block_size,
+            max_s,
+            None,
+        )
+    else:
+        # Run PagedAttention V2.
+        assert _PARTITION_SIZE % block_size == 0
+        tmp_output = torch.empty(
+            size=(num_seqs, num_heads, max_num_partitions, head_size),
+            dtype=out.dtype,
+            device=out.device,
+        )
+        exp_sums = torch.empty(
+            size=(num_seqs, num_heads, max_num_partitions),
+            dtype=torch.float32,
+            device=out.device,
+        )
+        max_logits = torch.empty_like(exp_sums)
+        attention_ops.paged_attention_v2(
+            out,
+            exp_sums,
+            max_logits,
+            tmp_output,
+            query,
+            key_cache,
+            value_cache,
+            kv_head_mapping,
+            softmax_scale,
+            block_tables,
+            input_lengths,
+            block_size,
+            max_s,
+            None,
+        )