Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-11 04:14:52 +00:00)

Merge branch 'huggingface:main' into main

Commit: 6a0a378c0c
.github/workflows/build_documentation.yml (1 change)

@@ -17,5 +17,4 @@ jobs:
       package: text-generation-inference
       additional_args: --not_python_module
     secrets:
-      token: ${{ secrets.HUGGINGFACE_PUSH }}
       hf_token: ${{ secrets.HF_DOC_BUILD_PUSH }}

Cargo.lock (generated, 44 changes)

@@ -743,18 +743,6 @@ version = "1.0.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "28a80e3145d8ad11ba0995949bbcf48b9df2be62772b3d351ef017dff6ecb853"

-[[package]]
-name = "flume"
-version = "0.11.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "55ac459de2512911e4b674ce33cf20befaba382d05b62b008afc1c8b57cbf181"
-dependencies = [
- "futures-core",
- "futures-sink",
- "nanorand",
- "spin 0.9.8",
-]
-
 [[package]]
 name = "fnv"
 version = "1.0.7"

@@ -900,10 +888,8 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "be4136b2a15dd319360be1c07d9933517ccf0be8f16bf62a3bee4f0d618df427"
 dependencies = [
  "cfg-if",
- "js-sys",
  "libc",
  "wasi",
- "wasm-bindgen",
 ]

 [[package]]

@@ -1508,15 +1494,6 @@ dependencies = [
  "tracing",
 ]

-[[package]]
-name = "nanorand"
-version = "0.7.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "6a51313c5820b0b02bd422f4b44776fbf47961755c74ce64afc73bfad10226c3"
-dependencies = [
- "getrandom",
-]
-
 [[package]]
 name = "native-tls"
 version = "0.2.11"

@@ -2313,7 +2290,7 @@ dependencies = [
  "cc",
  "libc",
  "once_cell",
- "spin 0.5.2",
+ "spin",
  "untrusted",
  "web-sys",
  "winapi",

@@ -2678,15 +2655,6 @@ version = "0.5.2"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "6e63cff320ae2c57904679ba7cb63280a3dc4613885beafb148ee7bf9aa9042d"

-[[package]]
-name = "spin"
-version = "0.9.8"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67"
-dependencies = [
- "lock_api",
-]
-
 [[package]]
 name = "spm_precompiled"
 version = "0.1.4"

@@ -2808,7 +2776,7 @@ dependencies = [

 [[package]]
 name = "text-generation-benchmark"
-version = "1.1.0"
+version = "1.1.1"
 dependencies = [
  "average",
  "clap",

@@ -2829,7 +2797,7 @@ dependencies = [

 [[package]]
 name = "text-generation-client"
-version = "1.1.0"
+version = "1.1.1"
 dependencies = [
  "futures",
  "grpc-metadata",

@@ -2845,7 +2813,7 @@ dependencies = [

 [[package]]
 name = "text-generation-launcher"
-version = "1.1.0"
+version = "1.1.1"
 dependencies = [
  "clap",
  "ctrlc",

@@ -2861,13 +2829,12 @@ dependencies = [

 [[package]]
 name = "text-generation-router"
-version = "1.1.0"
+version = "1.1.1"
 dependencies = [
  "async-stream",
  "axum",
  "axum-tracing-opentelemetry",
  "clap",
- "flume",
  "futures",
  "hf-hub 0.3.1",
  "init-tracing-opentelemetry",

@@ -2885,6 +2852,7 @@ dependencies = [
  "thiserror",
  "tokenizers",
  "tokio",
+ "tokio-stream",
  "tower-http",
  "tracing",
  "tracing-opentelemetry",

@@ -4,7 +4,7 @@ Text Generation Inference improves the model in several aspects.

 ## Quantization

-TGI supports [bits-and-bytes](https://github.com/TimDettmers/bitsandbytes#bitsandbytes), [GPT-Q](https://arxiv.org/abs/2210.17323) and [AWQ](https://arxiv.org/abs/2306.00978) quantization. To speed up inference with quantization, simply set `quantize` flag to `bitsandbytes`, `gptq` or `awq` depending on the quantization technique you wish to use. When using GPT-Q quantization, you need to point to one of the models [here](https://huggingface.co/models?search=gptq) when using AWQ quantization, you need to point to one of the models [here](https://huggingface.co/models?search=awq). To get more information about quantization, please refer to [quantization guide](./conceptual/quantization.md)
+TGI supports [bits-and-bytes](https://github.com/TimDettmers/bitsandbytes#bitsandbytes), [GPT-Q](https://arxiv.org/abs/2210.17323) and [AWQ](https://arxiv.org/abs/2306.00978) quantization. To speed up inference with quantization, simply set `quantize` flag to `bitsandbytes`, `gptq` or `awq` depending on the quantization technique you wish to use. When using GPT-Q quantization, you need to point to one of the models [here](https://huggingface.co/models?search=gptq) when using AWQ quantization, you need to point to one of the models [here](https://huggingface.co/models?search=awq). To get more information about quantization, please refer to [quantization guide](./../conceptual/quantization.md)


 ## RoPE Scaling

@@ -20,7 +20,6 @@ axum = { version = "0.6.20", features = ["json"] }
 axum-tracing-opentelemetry = "0.14.1"
 text-generation-client = { path = "client" }
 clap = { version = "4.4.5", features = ["derive", "env"] }
-flume = "0.11.0"
 futures = "0.3.28"
 metrics = "0.21.1"
 metrics-exporter-prometheus = { version = "0.12.1", features = [] }

@@ -34,6 +33,7 @@ serde_json = "1.0.107"
 thiserror = "1.0.48"
 tokenizers = { version = "0.14.0", features = ["http"] }
 tokio = { version = "1.32.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] }
+tokio-stream = "0.1.14"
 tower-http = { version = "0.4.4", features = ["cors"] }
 tracing = "0.1.37"
 tracing-opentelemetry = "0.21.0"

@@ -103,17 +103,18 @@ impl Client {
         &mut self,
         max_input_length: u32,
         max_prefill_tokens: u32,
+        max_total_tokens: u32,
     ) -> Result<Option<u32>> {
         let mut n_tokens = 0;
         let mut requests = Vec::new();

         // Create requests
         while n_tokens < max_prefill_tokens {
+            let truncate = min(max_input_length, max_prefill_tokens - n_tokens);
             requests.push(Request {
                 id: 0,
                 // We truncate the input on the server side to be sure that it has the correct size
                 inputs: "_test ".to_string().repeat(max_input_length as usize),
-                truncate: min(max_input_length, max_prefill_tokens - n_tokens),
+                truncate,
                 // Set sampling parameters to also take these ops into account in the max memory
                 parameters: Some(NextTokenChooserParameters {
                     temperature: 0.9,

@@ -126,9 +127,9 @@ impl Client {
                     watermark: true,
                 }),
                 stopping_parameters: Some(StoppingCriteriaParameters {
-                    max_new_tokens: 2,
+                    max_new_tokens: max_total_tokens - truncate,
                     stop_sequences: vec![],
-                    ignore_eos_token: false,
+                    ignore_eos_token: true,
                 }),
                 prefill_logprobs: true,
                 top_n_tokens: 20,

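The warmup change above sizes each synthetic request so prefill fills max_prefill_tokens and decoding can run all the way to max_total_tokens (ignore_eos_token: true keeps generation from stopping early). A minimal standalone sketch of that budgeting, outside TGI's types — the warmup_plan helper and the example numbers are illustrative, not from the diff:

use std::cmp::min;

// For each synthetic warmup request: how many input tokens to keep (truncate)
// and how many new tokens it may decode (max_total_tokens - truncate, clamped
// here with saturating_sub so the sketch cannot underflow).
fn warmup_plan(max_input_length: u32, max_prefill_tokens: u32, max_total_tokens: u32) -> Vec<(u32, u32)> {
    let mut plan = Vec::new();
    let mut n_tokens = 0;
    while n_tokens < max_prefill_tokens {
        let truncate = min(max_input_length, max_prefill_tokens - n_tokens);
        plan.push((truncate, max_total_tokens.saturating_sub(truncate)));
        n_tokens += max_input_length;
    }
    plan
}

fn main() {
    // Example: 1024-token inputs, a 4096-token prefill budget, 2048 total tokens per request.
    for (truncate, max_new_tokens) in warmup_plan(1024, 4096, 2048) {
        println!("truncate={truncate} max_new_tokens={max_new_tokens}");
    }
}
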
@@ -95,11 +95,14 @@ impl ShardedClient {
         &mut self,
         max_input_length: u32,
         max_prefill_tokens: u32,
+        max_total_tokens: u32,
     ) -> Result<Option<u32>> {
         let futures: Vec<_> = self
             .clients
             .iter_mut()
-            .map(|client| Box::pin(client.warmup(max_input_length, max_prefill_tokens)))
+            .map(|client| {
+                Box::pin(client.warmup(max_input_length, max_prefill_tokens, max_total_tokens))
+            })
             .collect();
         // Take the minimum value
         let results = join_all(futures)

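The sharded client simply forwards the extra max_total_tokens argument and, as before, keeps the smallest capacity reported by any shard. A standalone sketch of that aggregation shape (the shard_warmup values are made up; only the join_all-then-min pattern mirrors the router code; assumes the tokio and futures crates):

use futures::future::join_all;

// Pretend per-shard warmup: each shard reports how many batch total tokens it can handle.
async fn shard_warmup(shard_id: u32) -> Result<Option<u32>, String> {
    Ok(Some(16_000 - shard_id)) // illustrative numbers only
}

#[tokio::main]
async fn main() -> Result<(), String> {
    let futures: Vec<_> = (0..4).map(|id| Box::pin(shard_warmup(id))).collect();
    // Fail if any shard failed, otherwise take the minimum value across shards.
    let results: Vec<Option<u32>> = join_all(futures).await.into_iter().collect::<Result<_, _>>()?;
    let max_supported = results.into_iter().flatten().min();
    println!("max supported batch total tokens: {max_supported:?}");
    Ok(())
}
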
@@ -2,22 +2,21 @@
 use crate::validation::{Validation, ValidationError};
 use crate::{Entry, Queue, Token};
 use crate::{GenerateRequest, PrefillToken};
-use flume::r#async::RecvStream;
-use flume::SendTimeoutError;
 use futures::future::try_join_all;
-use futures::stream::StreamExt;
 use nohash_hasher::IntMap;
 use std::sync::{
     atomic::{AtomicBool, Ordering},
     Arc,
 };
-use std::time::Duration;
 use text_generation_client::{
     Batch, CachedBatch, ClientError, GeneratedText, Generation, PrefillTokens, ShardedClient,
 };
 use thiserror::Error;
-use tokio::sync::{Notify, OwnedSemaphorePermit, Semaphore, TryAcquireError};
+use tokio::sync::mpsc::error::SendError;
+use tokio::sync::{mpsc, Notify, OwnedSemaphorePermit, Semaphore, TryAcquireError};
 use tokio::time::Instant;
+use tokio_stream::wrappers::UnboundedReceiverStream;
+use tokio_stream::StreamExt;
 use tracing::{info_span, instrument, Instrument, Span};

 /// Inference struct

@@ -90,7 +89,7 @@ impl Infer {
     ) -> Result<
         (
             OwnedSemaphorePermit,
-            RecvStream<Result<InferStreamResponse, InferError>>,
+            UnboundedReceiverStream<Result<InferStreamResponse, InferError>>,
         ),
         InferError,
     > {

@@ -113,7 +112,7 @@ impl Infer {
         })?;

         // MPSC channel to communicate with the background batching task
-        let (response_tx, response_rx) = flume::unbounded();
+        let (response_tx, response_rx) = mpsc::unbounded_channel();

         // Append the request to the queue
         self.queue.append(Entry {

@@ -130,7 +129,7 @@ impl Infer {
         self.shared.batching_task.notify_one();

         // Return stream
-        Ok((permit, response_rx.into_stream()))
+        Ok((permit, UnboundedReceiverStream::new(response_rx)))
     }

     /// Add a new request to the queue and return a InferResponse

@@ -493,10 +492,7 @@ fn filter_send_generations(generations: Vec<Generation>, entries: &mut IntMap<u64, Entry>) {
         // If the receive an error from the Flume channel, it means that the client dropped the
         // request and we need to stop generating hence why we unwrap_or(true)
         let stopped = send_responses(generation, entry).map_err(|err| {
-            if let SendTimeoutError::Timeout(_) = *err {
-                tracing::error!("Entry response channel timed out.")
-            }
-
+            tracing::error!("Entry response channel error.");
             metrics::increment_counter!("tgi_request_failure", "err" => "dropped");
             err
         }).unwrap_or(true);

@@ -510,9 +506,10 @@ fn filter_send_generations(generations: Vec<Generation>, entries: &mut IntMap<u64, Entry>) {
 fn send_responses(
     generation: Generation,
     entry: &Entry,
-) -> Result<bool, Box<SendTimeoutError<Result<InferStreamResponse, InferError>>>> {
+) -> Result<bool, Box<SendError<Result<InferStreamResponse, InferError>>>> {
     // Return directly if the channel is disconnected
-    if entry.response_tx.is_disconnected() {
+    if entry.response_tx.is_closed() {
+        metrics::increment_counter!("tgi_request_failure", "err" => "dropped");
         return Ok(true);
     }

@@ -520,10 +517,9 @@ fn send_responses(

     if let Some(prefill_tokens) = generation.prefill_tokens {
         // Send message
-        entry.response_tx.send_timeout(
-            Ok(InferStreamResponse::Prefill(prefill_tokens)),
-            Duration::from_millis(10),
-        )?;
+        entry
+            .response_tx
+            .send(Ok(InferStreamResponse::Prefill(prefill_tokens)))?;
     }

     // Create last Token

@@ -558,22 +554,18 @@ fn send_responses(
         // Generation has ended
         stopped = true;
         // Send message
-        entry.response_tx.send_timeout(
-            Ok(InferStreamResponse::End {
-                token,
-                top_tokens,
-                generated_text,
-                queued: entry.queue_time,
-                start: entry.batch_time.unwrap(),
-            }),
-            Duration::from_millis(10),
-        )?;
+        entry.response_tx.send(Ok(InferStreamResponse::End {
+            token,
+            top_tokens,
+            generated_text,
+            queued: entry.queue_time,
+            start: entry.batch_time.unwrap(),
+        }))?;
     } else {
         // Send message
-        entry.response_tx.send_timeout(
-            Ok(InferStreamResponse::Intermediate { token, top_tokens }),
-            Duration::from_millis(10),
-        )?;
+        entry
+            .response_tx
+            .send(Ok(InferStreamResponse::Intermediate { token, top_tokens }))?;
     }
     Ok(stopped)
 }

@@ -591,7 +583,7 @@ fn send_errors(error: ClientError, entries: &mut IntMap<u64, Entry>) {
         // unwrap_or is valid here as we don't care if the receiver is gone.
         entry
             .response_tx
-            .send_timeout(Err(err), Duration::from_millis(10))
+            .send(Err(err))
             .unwrap_or(());
     });
 }

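Most of this file is the mechanical part of the flume-to-tokio swap: flume::unbounded() becomes mpsc::unbounded_channel(), the receiver is exposed as a Stream through tokio-stream's UnboundedReceiverStream, and send_timeout(..., 10ms) becomes a plain send that only errors once the receiver is gone. A self-contained sketch of that pattern (not TGI code; assumes the tokio and tokio-stream crates):

use tokio::sync::mpsc;
use tokio_stream::{wrappers::UnboundedReceiverStream, StreamExt};

#[tokio::main]
async fn main() {
    // Unbounded channel: the producer side never waits, so no send timeout is needed.
    let (tx, rx) = mpsc::unbounded_channel::<u32>();

    tokio::spawn(async move {
        for i in 0..3 {
            // send() only fails if the receiver was dropped (the client went away).
            if tx.send(i).is_err() {
                break;
            }
        }
    });

    // Wrap the receiver so it can be returned wherever a Stream is expected,
    // mirroring `Ok((permit, UnboundedReceiverStream::new(response_rx)))` above.
    let mut stream = UnboundedReceiverStream::new(rx);
    while let Some(value) = stream.next().await {
        println!("got {value}");
    }
}
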
@@ -212,7 +212,7 @@ fn main() -> Result<(), RouterError> {
     // Warmup model
     tracing::info!("Warming up model");
     let max_supported_batch_total_tokens = match sharded_client
-        .warmup(max_input_length as u32, max_batch_prefill_tokens)
+        .warmup(max_input_length as u32, max_batch_prefill_tokens, max_total_tokens as u32)
         .await
         .map_err(RouterError::Warmup)?
     {

@@ -5,7 +5,7 @@ use nohash_hasher::{BuildNoHashHasher, IntMap};
 use std::cmp::min;
 use std::collections::VecDeque;
 use text_generation_client::{Batch, Request};
-use tokio::sync::oneshot;
+use tokio::sync::{mpsc, oneshot};
 use tokio::time::Instant;
 use tracing::{info_span, instrument, Span};

@@ -15,7 +15,7 @@ pub(crate) struct Entry {
     /// Request
     pub request: ValidGenerateRequest,
     /// Response sender to communicate between the Infer struct and the batching_task
-    pub response_tx: flume::Sender<Result<InferStreamResponse, InferError>>,
+    pub response_tx: mpsc::UnboundedSender<Result<InferStreamResponse, InferError>>,
     /// Span that will live as long as entry
     pub span: Span,
     /// Temporary span used as a guard when logging inference, wait times...

@@ -30,13 +30,13 @@ pub(crate) struct Entry {
 #[derive(Debug, Clone)]
 pub(crate) struct Queue {
     /// Channel to communicate with the background queue task
-    queue_sender: flume::Sender<QueueCommand>,
+    queue_sender: mpsc::UnboundedSender<QueueCommand>,
 }

 impl Queue {
     pub(crate) fn new(requires_padding: bool, block_size: u32, window_size: Option<u32>) -> Self {
         // Create channel
-        let (queue_sender, queue_receiver) = flume::unbounded();
+        let (queue_sender, queue_receiver) = mpsc::unbounded_channel();

         // Launch background queue task
         tokio::spawn(queue_task(

@@ -91,11 +91,11 @@ async fn queue_task(
     requires_padding: bool,
     block_size: u32,
     window_size: Option<u32>,
-    receiver: flume::Receiver<QueueCommand>,
+    mut receiver: mpsc::UnboundedReceiver<QueueCommand>,
 ) {
     let mut state = State::new(requires_padding, block_size, window_size);

-    while let Ok(cmd) = receiver.recv_async().await {
+    while let Some(cmd) = receiver.recv().await {
         match cmd {
             QueueCommand::Append(entry, span) => {
                 span.in_scope(|| state.append(*entry));

@@ -195,7 +195,7 @@ impl State {
         while let Some((id, mut entry)) = self.entries.pop_front() {
             // Filter entries where the response receiver was dropped (== entries where the request
             // was dropped by the client)
-            if entry.response_tx.is_disconnected() {
+            if entry.response_tx.is_closed() {
                 metrics::increment_counter!("tgi_request_failure", "err" => "dropped");
                 continue;
             }

@@ -321,9 +321,9 @@ mod tests {

     fn default_entry() -> (
         Entry,
-        flume::Receiver<Result<InferStreamResponse, InferError>>,
+        mpsc::UnboundedReceiver<Result<InferStreamResponse, InferError>>,
     ) {
-        let (response_tx, receiver_tx) = flume::unbounded();
+        let (response_tx, receiver_tx) = mpsc::unbounded_channel();

         let entry = Entry {
             request: ValidGenerateRequest {

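The queue keeps its behaviour; only the channel primitives change: recv_async().await becomes recv().await (which yields None once every sender is dropped), and is_disconnected() becomes is_closed(). A small sketch of the dropped-client check (assumes the tokio crate):

use tokio::sync::mpsc;

#[tokio::main]
async fn main() {
    let (response_tx, response_rx) = mpsc::unbounded_channel::<&'static str>();
    assert!(!response_tx.is_closed());

    // The client dropped its response stream, e.g. the HTTP request was cancelled.
    drop(response_rx);

    // The queue can now skip this entry instead of batching it.
    assert!(response_tx.is_closed());
    assert!(response_tx.send("late token").is_err());
}
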
@@ -6,6 +6,7 @@ use text_generation_client::{NextTokenChooserParameters, StoppingCriteriaParameters};
 use thiserror::Error;
 use tokenizers::tokenizer::Tokenizer;
 use tokenizers::TruncationDirection;
+use tokio::sync::mpsc;
 use tokio::sync::oneshot;
 use tracing::{instrument, Span};

@@ -19,7 +20,7 @@ pub struct Validation {
     max_input_length: usize,
     max_total_tokens: usize,
     /// Channel to communicate with the background tokenization task
-    sender: Option<flume::Sender<TokenizerRequest>>,
+    sender: Option<mpsc::UnboundedSender<TokenizerRequest>>,
 }

 impl Validation {

@@ -34,19 +35,25 @@ impl Validation {
     ) -> Self {
         // If we have a fast tokenizer
         let sender = if let Some(tokenizer) = tokenizer {
-            // Create channel
-            let (validation_sender, validation_receiver) = flume::unbounded();
+            // Create round robin channel
+            let (validation_sender, validation_round_robin_receiver) = mpsc::unbounded_channel();
+            let mut senders = Vec::with_capacity(workers);

             // Create workers
             for _ in 0..workers {
                 let tokenizer_clone = tokenizer.clone();
-                let receiver_clone = validation_receiver.clone();
+                let (tokenizer_sender, tokenizer_receiver) = mpsc::unbounded_channel();
+                senders.push(tokenizer_sender);
+
                 // Spawn worker
                 tokio::task::spawn_blocking(move || {
-                    tokenizer_worker(tokenizer_clone, receiver_clone)
+                    tokenizer_worker(tokenizer_clone, tokenizer_receiver)
                 });
             }

+            // Create tokenization round robin task
+            tokio::spawn(round_robin_task(validation_round_robin_receiver, senders));
+
             Some(validation_sender)
         } else {
             None

@@ -118,12 +125,10 @@ impl Validation {
         // We make sure that truncate + max_new_tokens <= self.max_total_tokens
         let max_new_tokens: u32 = if let Some(max_new_tokens) = max_new_tokens {
             max_new_tokens
+        } else if let Some(truncate) = truncate {
+            self.max_total_tokens.saturating_sub(truncate) as u32
         } else {
-            if let Some(truncate) = truncate {
-                self.max_total_tokens.saturating_sub(truncate) as u32
-            } else {
-                return Err(ValidationError::UnsetMaxNewTokens)
-            }
+            return Err(ValidationError::UnsetMaxNewTokens);
         };
         let input_length = truncate.unwrap_or(self.max_input_length);

@@ -309,10 +314,25 @@ impl Validation {
     }
 }

+/// Round robin tokenization task
+async fn round_robin_task(
+    mut receiver: mpsc::UnboundedReceiver<TokenizerRequest>,
+    senders: Vec<mpsc::UnboundedSender<TokenizerRequest>>,
+) {
+    loop {
+        for sender in &senders {
+            match receiver.recv().await {
+                None => return,
+                Some(request) => sender.send(request).unwrap(),
+            };
+        }
+    }
+}
+
 /// Start tokenization workers
-fn tokenizer_worker(tokenizer: Tokenizer, receiver: flume::Receiver<TokenizerRequest>) {
+fn tokenizer_worker(tokenizer: Tokenizer, mut receiver: mpsc::UnboundedReceiver<TokenizerRequest>) {
     // Loop over requests
-    while let Ok(((inputs, truncate), response_tx, parent_span)) = receiver.recv() {
+    while let Some(((inputs, truncate), response_tx, parent_span)) = receiver.blocking_recv() {
         parent_span.in_scope(|| {
             response_tx
                 .send(prepare_input(inputs, truncate, &tokenizer))

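With flume, every tokenizer worker cloned one shared receiver and competed for messages; tokio's mpsc receivers cannot be cloned, so the diff gives each blocking worker its own channel and adds a round_robin_task that forwards requests to the workers in turn. A standalone sketch of that fan-out (illustrative names, not TGI code; assumes the tokio crate):

use tokio::sync::mpsc;

async fn round_robin<T: Send + 'static>(
    mut receiver: mpsc::UnboundedReceiver<T>,
    senders: Vec<mpsc::UnboundedSender<T>>,
) {
    loop {
        for sender in &senders {
            match receiver.recv().await {
                None => return, // all request senders dropped: shut down
                Some(request) => sender.send(request).unwrap(),
            }
        }
    }
}

#[tokio::main]
async fn main() {
    let (request_tx, request_rx) = mpsc::unbounded_channel::<String>();
    let mut worker_senders = Vec::new();

    for worker_id in 0..2 {
        let (tx, mut rx) = mpsc::unbounded_channel::<String>();
        worker_senders.push(tx);
        // blocking_recv() is the synchronous receive used inside spawn_blocking workers.
        tokio::task::spawn_blocking(move || {
            while let Some(job) = rx.blocking_recv() {
                println!("worker {worker_id} tokenizing {job}");
            }
        });
    }

    tokio::spawn(round_robin(request_rx, worker_senders));

    for i in 0..4 {
        request_tx.send(format!("request {i}")).unwrap();
    }
    tokio::time::sleep(std::time::Duration::from_millis(50)).await;
}
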
@@ -1,4 +1,4 @@
-flash_att_v2_commit := 601b4dc48dbe9d87c468daa2b4c0c8388b83753c
+flash_att_v2_commit := 02ac572f3ffc4f402e4183aaa6824b45859d3ed3

 flash-attention-v2:
 	# Clone flash attention

@@ -1,8 +1,8 @@
-vllm_commit := 25dbff97d5a8f2ba331847237b458b2692e9ae78
+vllm_commit := f8a1e39fae05ca610be8d5a78be9d40f5274e5fc

 vllm:
 	# Clone vllm
-	git clone https://github.com/OlivierDehaene/vllm.git
+	git clone https://github.com/vllm-project/vllm.git

 build-vllm: vllm
 	cd vllm && git fetch && git checkout $(vllm_commit)

@@ -511,7 +511,7 @@ class CausalLM(Model):
             load_in_8bit=quantize == "bitsandbytes",
             trust_remote_code=trust_remote_code,
         )
-        if torch.cuda.is_available() and torch.cuda.device_count() == 1:
+        if torch.cuda.is_available() and torch.cuda.device_count() == 1 and quantize != "bitsandbytes":
            model = model.cuda()

        if tokenizer.pad_token_id is None:

@@ -29,11 +29,7 @@ from typing import Optional, List, Tuple
 # Flash attention imports
 import dropout_layer_norm

-# vllm imports
-import vllm_cache_ops
-import vllm_attention_ops
-
-from text_generation_server.utils.flash_attn import attention
+from text_generation_server.utils import paged_attention, flash_attn
 from text_generation_server.utils.layers import (
     TensorParallelRowLinear,
     TensorParallelColumnLinear,

@@ -269,7 +265,7 @@ class FlashLlamaAttention(torch.nn.Module):
         self.rotary_emb(query, cos, sin)
         self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin)

-        vllm_cache_ops.reshape_and_cache(
+        paged_attention.reshape_and_cache(
             kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots
         )

@@ -279,7 +275,7 @@ class FlashLlamaAttention(torch.nn.Module):
         # Prefill
         if cu_seqlen_prefill is not None:
             # flash attention
-            attention(
+            flash_attn.attention(
                 query,
                 torch.select(kv, dim=1, index=0),
                 torch.select(kv, dim=1, index=1),

@@ -290,9 +286,7 @@ class FlashLlamaAttention(torch.nn.Module):
             )
         # Decode
         else:
-            # kv_cache[1] => [num_blocks, num_heads, head_size, block_size]
-            block_size = kv_cache[1].shape[3]
-            vllm_attention_ops.single_query_cached_kv_attention(
+            paged_attention.attention(
                 attn_output,
                 query,
                 kv_cache[0],

@@ -301,7 +295,6 @@ class FlashLlamaAttention(torch.nn.Module):
                 self.softmax_scale,
                 block_tables,
                 input_lengths,
-                block_size,
                 max_s,
             )

@@ -29,10 +29,7 @@ from typing import Optional, List, Tuple
 # Flash attention imports
 import dropout_layer_norm

-# vllm imports
-import vllm_cache_ops
-import vllm_attention_ops
-
+from text_generation_server.utils import paged_attention, flash_attn
 from text_generation_server.utils.flash_attn import attention, HAS_FLASH_ATTN_V2
 from text_generation_server.utils.layers import (
     TensorParallelRowLinear,

@@ -272,7 +269,7 @@ class MistralAttention(torch.nn.Module):
         else:
             kv_to_cache = kv

-        vllm_cache_ops.reshape_and_cache(
+        paged_attention.reshape_and_cache(
             kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots
         )

@@ -282,7 +279,7 @@ class MistralAttention(torch.nn.Module):
         # Prefill
         if cu_seqlen_prefill is not None:
             # flash attention
-            attention(
+            flash_attn.attention(
                 query,
                 torch.select(kv, dim=1, index=0),
                 torch.select(kv, dim=1, index=1),

@@ -294,9 +291,7 @@ class MistralAttention(torch.nn.Module):
             )
         # Decode
         else:
-            # kv_cache[1] => [num_blocks, num_heads, head_size, block_size]
-            block_size = kv_cache[1].shape[3]
-            vllm_attention_ops.single_query_cached_kv_attention(
+            paged_attention.attention(
                 attn_output,
                 query,
                 kv_cache[0],

@@ -305,7 +300,6 @@ class MistralAttention(torch.nn.Module):
                 self.softmax_scale,
                 block_tables,
                 input_lengths,
-                block_size,
                 max_s,
             )

@@ -27,10 +27,7 @@ from transformers.modeling_utils import PreTrainedModel
 from transformers.models.gpt_neox import GPTNeoXConfig
 from typing import Optional, List, Tuple

-# vllm imports
-import vllm_cache_ops
-import vllm_attention_ops
-
+from text_generation_server.utils import paged_attention, flash_attn
 from text_generation_server.utils.flash_attn import attention
 from text_generation_server.utils.layers import (
     TensorParallelRowLinear,

@@ -141,7 +138,7 @@ class FlashNeoxAttention(torch.nn.Module):
         self.rotary_emb(qkv[:, 0], cos, sin)
         self.rotary_emb(qkv[:, 1], cos, sin)

-        vllm_cache_ops.reshape_and_cache(
+        paged_attention.reshape_and_cache(
             qkv[:, 1], qkv[:, 2], kv_cache[0], kv_cache[1], slots
         )

@@ -151,7 +148,7 @@ class FlashNeoxAttention(torch.nn.Module):
         # Prefill
         if cu_seqlen_prefill is not None:
             # flash attention
-            attention(
+            flash_attn.attention(
                 qkv[:, 0],
                 qkv[:, 1],
                 qkv[:, 2],

@@ -162,9 +159,7 @@ class FlashNeoxAttention(torch.nn.Module):
             )
         # Decode
         else:
-            # kv_cache[1] => [num_blocks, num_heads, head_size, block_size]
-            block_size = kv_cache[1].shape[3]
-            vllm_attention_ops.single_query_cached_kv_attention(
+            paged_attention.attention(
                 attn_output,
                 qkv[:, 0],
                 kv_cache[0],

@@ -173,7 +168,6 @@ class FlashNeoxAttention(torch.nn.Module):
                 self.softmax_scale,
                 block_tables,
                 input_lengths,
-                block_size,
                 max_s,
             )

@@ -6,10 +6,7 @@ from transformers.modeling_utils import PreTrainedModel
 from transformers.configuration_utils import PretrainedConfig
 from typing import Optional, List, Tuple

-# vllm imports
-import vllm_cache_ops
-import vllm_attention_ops
-
+from text_generation_server.utils import paged_attention, flash_attn
 from text_generation_server.utils.flash_attn import attention
 from text_generation_server.utils.layers import (
     TensorParallelRowLinear,

@@ -191,7 +188,7 @@ class FlashRWAttention(torch.nn.Module):
         self.rotary_emb(query, cos, sin)
         self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin)

-        vllm_cache_ops.reshape_and_cache(
+        paged_attention.reshape_and_cache(
             kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots
         )

@@ -201,7 +198,7 @@ class FlashRWAttention(torch.nn.Module):
         # Prefill
         if cu_seqlen_prefill is not None:
             # flash attention
-            attention(
+            flash_attn.attention(
                 query,
                 torch.select(kv, dim=1, index=0),
                 torch.select(kv, dim=1, index=1),

@@ -212,9 +209,7 @@ class FlashRWAttention(torch.nn.Module):
             )
         # Decode
         else:
-            # kv_cache[1] => [num_blocks, num_heads_kv, head_size, block_size]
-            block_size = kv_cache[1].shape[3]
-            vllm_attention_ops.single_query_cached_kv_attention(
+            paged_attention.attention(
                 attn_output,
                 query,
                 kv_cache[0],

@@ -223,7 +218,6 @@ class FlashRWAttention(torch.nn.Module):
                 self.softmax_scale,
                 block_tables,
                 input_lengths,
-                block_size,
                 max_s,
             )

@@ -310,7 +304,7 @@ class FlashRWLargeAttention(torch.nn.Module):
         self.rotary_emb(query, cos, sin)
         self.rotary_emb(torch.select(kv, dim=2, index=0), cos, sin)

-        vllm_cache_ops.reshape_and_cache(
+        paged_attention.reshape_and_cache(
             kv[:, :, 0].contiguous(),
             kv[:, :, 1].contiguous(),
             kv_cache[0],

@@ -324,7 +318,7 @@ class FlashRWLargeAttention(torch.nn.Module):
         # Prefill
         if cu_seqlen_prefill is not None:
             # flash attention
-            attention(
+            flash_attn.attention(
                 query,
                 torch.select(kv, dim=2, index=0),
                 torch.select(kv, dim=2, index=1),

@@ -335,9 +329,7 @@ class FlashRWLargeAttention(torch.nn.Module):
             )
         # Decode
         else:
-            # kv_cache[1] => [num_blocks, num_groups, head_size, block_size]
-            block_size = kv_cache[1].shape[3]
-            vllm_attention_ops.single_query_cached_kv_attention(
+            paged_attention.attention(
                 attn_output,
                 query,
                 kv_cache[0],

@@ -346,7 +338,6 @@ class FlashRWLargeAttention(torch.nn.Module):
                 self.softmax_scale,
                 block_tables,
                 input_lengths,
-                block_size,
                 max_s,
             )

@@ -5,10 +5,7 @@ from torch import nn
 from transformers.activations import ACT2FN
 from typing import Optional, List, Tuple

-# vllm imports
-import vllm_cache_ops
-import vllm_attention_ops
-
+from text_generation_server.utils import paged_attention, flash_attn
 from text_generation_server.utils.flash_attn import attention
 from text_generation_server.utils.layers import (
     TensorParallelRowLinear,

@@ -18,7 +15,6 @@ from text_generation_server.utils.layers import (
     FastLayerNorm,
     get_linear,
 )
-from safetensors import SafetensorError


 def load_multi_mqa(

@@ -258,7 +254,7 @@ class FlashMQAttention(torch.nn.Module):
         query = query.view(-1, self.num_heads, self.head_size)
         key_value = key_value.view(-1, 2, 1, self.head_size)

-        vllm_cache_ops.reshape_and_cache(
+        paged_attention.reshape_and_cache(
             key_value[:, 0], key_value[:, 1], kv_cache[0], kv_cache[1], slots
         )

@@ -268,7 +264,7 @@ class FlashMQAttention(torch.nn.Module):
         # Prefill
         if cu_seqlen_prefill is not None:
             # flash attention
-            attention(
+            flash_attn.attention(
                 query,
                 torch.select(key_value, dim=1, index=0),
                 torch.select(key_value, dim=1, index=1),

@@ -279,9 +275,7 @@ class FlashMQAttention(torch.nn.Module):
             )
         # Decode
         else:
-            # kv_cache[1] => [num_blocks, 1, head_size, block_size]
-            block_size = kv_cache[1].shape[3]
-            vllm_attention_ops.single_query_cached_kv_attention(
+            paged_attention.attention(
                 attn_output,
                 query,
                 kv_cache[0],

@@ -290,7 +284,6 @@ class FlashMQAttention(torch.nn.Module):
                 self.softmax_scale,
                 block_tables,
                 input_lengths,
-                block_size,
                 max_s,
             )

@@ -283,10 +283,10 @@ class GPTNeoXAttention(nn.Module):
         batch_size, num_attention_heads, query_length, attn_head_size = query.size()
         key_length = key.size(-2)

-        query = query.view(
+        query = query.reshape(
             batch_size * num_attention_heads, query_length, attn_head_size
         )
-        key = key.view(batch_size * num_attention_heads, key_length, attn_head_size)
+        key = key.reshape(batch_size * num_attention_heads, key_length, attn_head_size)
         attn_scores = torch.zeros(
             1,
             dtype=query.dtype,

@@ -670,7 +670,7 @@ class FlashCausalLM(Model):
                 self.device,
             )
             _, batch = self.generate_token(batch)
-        except Exception as e:
+        except torch.cuda.OutOfMemoryError as e:
             raise RuntimeError(
                 f"Not enough memory to handle {len(batch.input_ids)} prefill tokens. "
                 f"You need to decrease `--max-batch-prefill-tokens`"

@@ -155,10 +155,7 @@ class EETQLinear(nn.Module):
         device = weight.device
         weight = torch.t(weight).contiguous().cpu()
         weight, scale = quant_weights(weight, torch.int8, False)
-        if bias:
-            bias = weights.get_tensor(f"{prefix}.bias")
-        else:
-            bias = None
+
         self.weight = weight.cuda(device)
         self.scale = scale.cuda(device)
         self.bias = bias.cuda(device) if bias is not None else None

server/text_generation_server/utils/paged_attention.py (new file, 100 lines)

@@ -0,0 +1,100 @@
+import torch
+
+# vllm imports
+from vllm import cache_ops
+from vllm import attention_ops
+
+_PARTITION_SIZE = 512
+
+
+def reshape_and_cache(key: torch.Tensor, value: torch.Tensor, key_cache: torch.Tensor, value_cache: torch.Tensor,
+                      slots: torch.Tensor):
+    cache_ops.reshape_and_cache(
+        key, value, key_cache, value_cache, slots
+    )
+
+
+def attention(
+    out: torch.Tensor,
+    query: torch.Tensor,
+    key_cache: torch.Tensor,
+    value_cache: torch.Tensor,
+    kv_head_mapping: torch.Tensor,
+    softmax_scale: float,
+    block_tables: torch.Tensor,
+    input_lengths: torch.Tensor,
+    max_s: int,
+):
+    # Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py
+    # Copyright 2023 The vLLM team. All rights
+    # reserved.
+    #
+    # Licensed under the Apache License, Version 2.0 (the "License");
+    # you may not use this file except in compliance with the License.
+    # You may obtain a copy of the License at
+    #
+    #     http://www.apache.org/licenses/LICENSE-2.0
+    #
+    # Unless required by applicable law or agreed to in writing, software
+    # distributed under the License is distributed on an "AS IS" BASIS,
+    # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+    # See the License for the specific language governing permissions and
+    # limitations under the License.
+    #
+
+    # value_cache => [num_blocks, num_heads, head_size, block_size]
+    block_size = value_cache.shape[3]
+    num_seqs, num_heads, head_size = query.shape
+    max_num_partitions = (
+        (max_s + _PARTITION_SIZE - 1) //
+        _PARTITION_SIZE)
+    # NOTE(woosuk): We use a simple heuristic to decide whether to use
+    # PagedAttention V1 or V2. If the number of partitions is 1, we use
+    # V1 to avoid the overhead of reduction. Also, if the number of
+    # sequences or heads is large, we use V1 since there is enough work
+    # to parallelize.
+    use_v1 = max_num_partitions == 1 or num_seqs * num_heads > 512
+    if use_v1:
+        attention_ops.paged_attention_v1(
+            out,
+            query,
+            key_cache,
+            value_cache,
+            kv_head_mapping,
+            softmax_scale,
+            block_tables,
+            input_lengths,
+            block_size,
+            max_s,
+            None,
+        )
+    else:
+        # Run PagedAttention V2.
+        assert _PARTITION_SIZE % block_size == 0
+        tmp_output = torch.empty(
+            size=(num_seqs, num_heads, max_num_partitions, head_size),
+            dtype=out.dtype,
+            device=out.device,
+        )
+        exp_sums = torch.empty(
+            size=(num_seqs, num_heads, max_num_partitions),
+            dtype=torch.float32,
+            device=out.device,
+        )
+        max_logits = torch.empty_like(exp_sums)
+        attention_ops.paged_attention_v2(
+            out,
+            exp_sums,
+            max_logits,
+            tmp_output,
+            query,
+            key_cache,
+            value_cache,
+            kv_head_mapping,
+            softmax_scale,
+            block_tables,
+            input_lengths,
+            block_size,
+            max_s,
+            None,
+        )

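For reference, the V1/V2 choice in this new helper comes down to one ceiling division (the concrete numbers below are an illustration, not values from the diff):

    max_num_partitions = ⌈max_s / _PARTITION_SIZE⌉, e.g. ⌈2048 / 512⌉ = 4

so a request with max_s = 2048 takes the partitioned V2 kernel unless num_seqs * num_heads > 512, in which case there is already enough parallel work and V1 is used; with max_s ≤ 512 the single partition always selects V1.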