Merge branch 'main' into patch-2

Dong Shin 2023-07-20 11:37:46 +09:00 committed by GitHub
commit a1859012c4
29 changed files with 680 additions and 352 deletions

Cargo.lock (generated)

@ -2893,7 +2893,7 @@ dependencies = [
[[package]]
name = "text-generation-benchmark"
version = "0.9.2"
version = "0.9.3"
dependencies = [
"average",
"clap",
@ -2913,7 +2913,7 @@ dependencies = [
[[package]]
name = "text-generation-client"
version = "0.9.2"
version = "0.9.3"
dependencies = [
"futures",
"grpc-metadata",
@ -2929,7 +2929,7 @@ dependencies = [
[[package]]
name = "text-generation-launcher"
version = "0.9.2"
version = "0.9.3"
dependencies = [
"clap",
"ctrlc",
@ -2945,7 +2945,7 @@ dependencies = [
[[package]]
name = "text-generation-router"
version = "0.9.2"
version = "0.9.3"
dependencies = [
"async-stream",
"axum",


@ -8,7 +8,7 @@ members = [
]
[workspace.package]
version = "0.9.2"
version = "0.9.3"
edition = "2021"
authors = ["Olivier Dehaene"]
homepage = "https://github.com/huggingface/text-generation-inference"


@ -98,6 +98,16 @@ COPY server/Makefile-flash-att Makefile
# Build specific version of flash attention
RUN make build-flash-attention
# Build Flash Attention v2 CUDA kernels
FROM kernel-builder as flash-att-v2-builder
WORKDIR /usr/src
COPY server/Makefile-flash-att-v2 Makefile
# Build specific version of flash attention v2
RUN make build-flash-attention-v2
# Build Transformers CUDA kernels
FROM kernel-builder as custom-kernels-builder
@ -146,8 +156,11 @@ COPY --from=flash-att-builder /usr/src/flash-attention/build/lib.linux-x86_64-cp
COPY --from=flash-att-builder /usr/src/flash-attention/csrc/layer_norm/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages
COPY --from=flash-att-builder /usr/src/flash-attention/csrc/rotary/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages
# Copy build artifacts from flash attention v2 builder
COPY --from=flash-att-v2-builder /usr/src/flash-attention-v2/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages
# Copy build artifacts from custom kernels builder
COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-39/custom_kernels /usr/src/custom-kernels/src/custom_kernels
COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages
# Copy build artifacts from vllm builder
COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages


@ -63,6 +63,8 @@ to power LLMs api-inference widgets.
- [Starcoder](https://huggingface.co/bigcode/starcoder)
- [Falcon 7B](https://huggingface.co/tiiuae/falcon-7b)
- [Falcon 40B](https://huggingface.co/tiiuae/falcon-40b)
- [MPT](https://huggingface.co/mosaicml/mpt-30b)
- [Llama V2](https://huggingface.co/meta-llama)
Other architectures are supported on a best effort basis using:
@ -132,6 +134,10 @@ print(text)
You can consult the OpenAPI documentation of the `text-generation-inference` REST API using the `/docs` route.
The Swagger UI is also available at: [https://huggingface.github.io/text-generation-inference](https://huggingface.github.io/text-generation-inference).
### Using private or gated models
You can use the `HUGGING_FACE_HUB_TOKEN` environment variable to set the token used by `text-generation-inference` to access protected resources.
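For example, a minimal way to pass the token when starting the launcher from Python; a sketch only, where the model id and the token value are placeholders to replace with your own:

import os
import subprocess

# Placeholder token and gated model id; substitute your own values.
env = {**os.environ, "HUGGING_FACE_HUB_TOKEN": "hf_xxx"}
subprocess.run(
    ["text-generation-launcher", "--model-id", "meta-llama/Llama-2-7b-hf"],
    env=env,
    check=True,
)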
### Distributed Tracing
`text-generation-inference` is instrumented with distributed tracing using OpenTelemetry. You can use this feature
@ -211,7 +217,7 @@ sudo apt-get install libssl-dev gcc -y
### CUDA Kernels
The custom CUDA kernels are only tested on NVIDIA A100s. If you have any installation or runtime issues, you can remove
the kernels by using the `BUILD_EXTENSIONS=False` environment variable.
the kernels by using the `DISABLE_CUSTOM_KERNELS=True` environment variable.
Be aware that the official Docker image has them enabled by default.


@ -10,7 +10,7 @@
"name": "Apache 2.0",
"url": "https://www.apache.org/licenses/LICENSE-2.0"
},
"version": "0.9.2"
"version": "0.9.3"
},
"paths": {
"/": {


@ -13,7 +13,7 @@ nix = "0.26.2"
serde = { version = "1.0.152", features = ["derive"] }
serde_json = "1.0.93"
tracing = "0.1.37"
tracing-subscriber = { version = "0.3.16", features = ["json"] }
tracing-subscriber = { version = "0.3.16", features = ["json", "env-filter"] }
[dev-dependencies]
float_eq = "1.0.1"


@ -4,7 +4,7 @@ use nix::unistd::Pid;
use serde::Deserialize;
use std::env;
use std::ffi::OsString;
use std::io::{BufRead, BufReader, Read};
use std::io::{BufRead, BufReader, Lines, Read};
use std::os::unix::process::{CommandExt, ExitStatusExt};
use std::path::Path;
use std::process::{Child, Command, ExitStatus, Stdio};
@ -15,6 +15,7 @@ use std::thread;
use std::thread::sleep;
use std::time::{Duration, Instant};
use std::{fs, io};
use tracing_subscriber::EnvFilter;
mod env_runtime;
@ -41,6 +42,7 @@ impl std::fmt::Display for Quantization {
#[derive(Clone, Copy, Debug, ValueEnum)]
enum Dtype {
Float16,
#[clap(name = "bfloat16")]
BFloat16,
}
@ -182,8 +184,8 @@ struct Args {
/// depends on other parameters like if you're using quantization, flash attention
/// or the model implementation, text-generation-inference cannot infer this number
/// automatically.
#[clap(default_value = "16000", long, env)]
max_batch_total_tokens: u32,
#[clap(long, env)]
max_batch_total_tokens: Option<u32>,
/// This setting defines how many tokens can be passed before forcing the waiting
/// queries to be put on the batch (if the size of the batch allows for it).
@ -265,17 +267,9 @@ struct Args {
#[clap(long, env)]
ngrok_authtoken: Option<String>,
/// ngrok domain name where the axum webserver will be available at
/// ngrok edge
#[clap(long, env)]
ngrok_domain: Option<String>,
/// ngrok basic auth username
#[clap(long, env)]
ngrok_username: Option<String>,
/// ngrok basic auth password
#[clap(long, env)]
ngrok_password: Option<String>,
ngrok_edge: Option<String>,
/// Display a lot of information about your runtime environment
#[clap(long, short, action)]
@ -285,7 +279,7 @@ struct Args {
#[derive(Debug)]
enum ShardStatus {
Ready,
Failed((usize, Option<String>)),
Failed(usize),
}
#[allow(clippy::too_many_arguments)]
@ -310,6 +304,9 @@ fn shard_manager(
shutdown: Arc<AtomicBool>,
_shutdown_sender: mpsc::Sender<()>,
) {
// Enter shard-manager tracing span
let _span = tracing::span!(tracing::Level::INFO, "shard-manager", rank = rank).entered();
// Get UDS path
let uds_string = format!("{uds_path}-{rank}");
let uds = Path::new(&uds_string);
@ -364,12 +361,6 @@ fn shard_manager(
// Copy current process env
let mut envs: Vec<(OsString, OsString)> = env::vars_os().collect();
// Use cuda allocator. It leads to less memory fragmentation
envs.push((
"PYTORCH_CUDA_ALLOC_CONF".into(),
"backend:cudaMallocAsync".into(),
));
// Torch Distributed Env vars
envs.push(("RANK".into(), rank.to_string().into()));
envs.push(("WORLD_SIZE".into(), world_size.to_string().into()));
@ -423,7 +414,7 @@ fn shard_manager(
}
// Start process
tracing::info!("Starting shard {rank}");
tracing::info!("Starting shard");
let mut p = match Command::new("text-generation-server")
.args(shard_args)
.envs(envs)
@ -437,30 +428,23 @@ fn shard_manager(
if err.kind() == io::ErrorKind::NotFound {
tracing::error!("text-generation-server not found in PATH");
tracing::error!("Please install it with `make install-server`")
} else {
}
{
tracing::error!("{}", err);
}
status_sender
.send(ShardStatus::Failed((rank, Some(err.to_string()))))
.unwrap();
status_sender.send(ShardStatus::Failed(rank)).unwrap();
return;
}
};
// Redirect STDOUT to the console
let shard_stdout_reader = BufReader::new(p.stdout.take().unwrap());
let mut shard_stderr_reader = BufReader::new(p.stderr.take().unwrap());
let shard_stderr_reader = BufReader::new(p.stderr.take().unwrap());
//stdout tracing thread
thread::spawn(move || {
// Enter shard-manager tracing span
let _span = tracing::span!(tracing::Level::INFO, "shard-manager", rank = rank).entered();
for line in shard_stdout_reader.lines() {
// Parse loguru logs
if let Ok(log) = serde_json::from_str::<PythonLogMessage>(&line.unwrap()) {
log.trace();
}
}
log_lines(shard_stdout_reader.lines());
});
let mut ready = false;
@ -469,30 +453,25 @@ fn shard_manager(
loop {
// Process exited
if let Some(exit_status) = p.try_wait().unwrap() {
// We read stderr in another thread as it seems that `read_to_string` can block
// indefinitely in some cases
// We read stderr in another thread as it seems that lines() can block in some cases
let (err_sender, err_receiver) = mpsc::channel();
thread::spawn(move || {
let mut err = String::new();
shard_stderr_reader.read_to_string(&mut err).unwrap();
err_sender.send(err).unwrap_or(());
for line in shard_stderr_reader.lines().flatten() {
err_sender.send(line).unwrap_or(());
}
});
let mut err = String::new();
while let Ok(line) = err_receiver.recv_timeout(Duration::from_millis(10)) {
err = err + "\n" + &line;
}
let err = err_receiver
.recv_timeout(Duration::from_millis(100))
.map_err(|err| {
tracing::error!("Unable to read shard {rank} error from stderr");
err
})
.ok();
tracing::error!("Shard complete standard error output:\n{err}");
if let Some(signal) = exit_status.signal() {
tracing::error!("Shard process was signaled to shutdown with signal {signal}");
}
status_sender
.send(ShardStatus::Failed((rank, err)))
.unwrap();
status_sender.send(ShardStatus::Failed(rank)).unwrap();
return;
}
@ -500,17 +479,17 @@ fn shard_manager(
if shutdown.load(Ordering::SeqCst) {
p.kill().unwrap();
let _ = p.wait();
tracing::info!("Shard {rank} terminated");
tracing::info!("Shard terminated");
return;
}
// Shard is ready
if uds.exists() && !ready {
tracing::info!("Shard {rank} ready in {:?}", start_time.elapsed());
tracing::info!("Shard ready in {:?}", start_time.elapsed());
status_sender.send(ShardStatus::Ready).unwrap();
ready = true;
} else if !ready && wait_time.elapsed() > Duration::from_secs(10) {
tracing::info!("Waiting for shard {rank} to be ready...");
tracing::info!("Waiting for shard to be ready...");
wait_time = Instant::now();
}
sleep(Duration::from_millis(100));
@ -579,6 +558,23 @@ impl PythonLogMessage {
}
}
impl TryFrom<&String> for PythonLogMessage {
type Error = serde_json::Error;
fn try_from(value: &String) -> Result<Self, Self::Error> {
serde_json::from_str::<Self>(value)
}
}
fn log_lines<S: Sized + BufRead>(lines: Lines<S>) {
for line in lines.flatten() {
match PythonLogMessage::try_from(&line) {
Ok(log) => log.trace(),
Err(_) => tracing::debug!("{line}"),
}
}
}
fn find_num_shards(
sharded: Option<bool>,
num_shard: Option<usize>,
@ -632,6 +628,9 @@ enum LauncherError {
}
fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), LauncherError> {
// Enter download tracing span
let _span = tracing::span!(tracing::Level::INFO, "download").entered();
let mut download_args = vec![
"download-weights".to_string(),
args.model_id.to_string(),
@ -693,6 +692,8 @@ fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), L
if err.kind() == io::ErrorKind::NotFound {
tracing::error!("text-generation-server not found in PATH");
tracing::error!("Please install it with `make install-server`")
} else {
tracing::error!("{}", err);
}
return Err(LauncherError::DownloadError);
@ -701,16 +702,10 @@ fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), L
// Redirect STDOUT to the console
let download_stdout = download_process.stdout.take().unwrap();
thread::spawn(move || {
// Enter download tracing span
let stdout = BufReader::new(download_stdout);
let _span = tracing::span!(tracing::Level::INFO, "download").entered();
for line in stdout.lines() {
// Parse loguru logs
if let Ok(log) = serde_json::from_str::<PythonLogMessage>(&line.unwrap()) {
log.trace();
}
}
thread::spawn(move || {
log_lines(stdout.lines());
});
loop {
@ -815,11 +810,8 @@ fn spawn_shards(
Err(TryRecvError::Empty) => {
sleep(Duration::from_millis(100));
}
Ok(ShardStatus::Failed((rank, err))) => {
Ok(ShardStatus::Failed(rank)) => {
tracing::error!("Shard {rank} failed to start");
if let Some(err) = err {
tracing::error!("{err}");
}
shutdown_shards(shutdown, shutdown_receiver);
return Err(LauncherError::ShardCannotStart);
}
@ -854,8 +846,6 @@ fn spawn_webserver(
args.max_total_tokens.to_string(),
"--max-batch-prefill-tokens".to_string(),
args.max_batch_prefill_tokens.to_string(),
"--max-batch-total-tokens".to_string(),
args.max_batch_total_tokens.to_string(),
"--waiting-served-ratio".to_string(),
args.waiting_served_ratio.to_string(),
"--max-waiting-tokens".to_string(),
@ -872,6 +862,12 @@ fn spawn_webserver(
args.model_id,
];
// Model optional max batch total tokens
if let Some(max_batch_total_tokens) = args.max_batch_total_tokens {
router_args.push("--max-batch-total-tokens".to_string());
router_args.push(max_batch_total_tokens.to_string());
}
// Model optional revision
if let Some(ref revision) = args.revision {
router_args.push("--revision".to_string());
@ -896,26 +892,11 @@ fn spawn_webserver(
// Ngrok
if args.ngrok {
let authtoken = args.ngrok_authtoken.ok_or_else(|| {
tracing::error!("`ngrok-authtoken` must be set when using ngrok tunneling");
LauncherError::WebserverCannotStart
})?;
router_args.push("--ngrok".to_string());
router_args.push("--ngrok-authtoken".to_string());
router_args.push(authtoken);
if let Some(domain) = args.ngrok_domain {
router_args.push("--ngrok-domain".to_string());
router_args.push(domain);
}
if let (Some(username), Some(password)) = (args.ngrok_username, args.ngrok_password) {
router_args.push("--ngrok-username".to_string());
router_args.push(username);
router_args.push("--ngrok-password".to_string());
router_args.push(password);
}
router_args.push(args.ngrok_authtoken.unwrap());
router_args.push("--ngrok-edge".to_string());
router_args.push(args.ngrok_edge.unwrap());
}
// Copy current process env
@ -993,12 +974,22 @@ fn terminate(process_name: &str, mut process: Child, timeout: Duration) -> io::R
fn main() -> Result<(), LauncherError> {
// Pattern match configuration
let args = Args::parse();
let args: Args = Args::parse();
// Filter events with LOG_LEVEL
let env_filter =
EnvFilter::try_from_env("LOG_LEVEL").unwrap_or_else(|_| EnvFilter::new("info"));
if args.json_output {
tracing_subscriber::fmt().json().init();
tracing_subscriber::fmt()
.with_env_filter(env_filter)
.json()
.init();
} else {
tracing_subscriber::fmt().compact().init();
tracing_subscriber::fmt()
.with_env_filter(env_filter)
.compact()
.init();
}
if args.env {
@ -1020,18 +1011,7 @@ fn main() -> Result<(), LauncherError> {
args.max_batch_prefill_tokens, args.max_input_length
)));
}
if args.max_batch_prefill_tokens > args.max_batch_total_tokens {
return Err(LauncherError::ArgumentValidation(format!(
"`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}",
args.max_batch_prefill_tokens, args.max_batch_total_tokens
)));
}
if args.max_total_tokens as u32 > args.max_batch_total_tokens {
return Err(LauncherError::ArgumentValidation(format!(
"`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}",
args.max_total_tokens, args.max_batch_total_tokens
)));
}
if args.validation_workers == 0 {
return Err(LauncherError::ArgumentValidation(
"`validation_workers` must be > 0".to_string(),
@ -1049,6 +1029,35 @@ fn main() -> Result<(), LauncherError> {
tracing::info!("Sharding model on {num_shard} processes");
}
if let Some(ref max_batch_total_tokens) = args.max_batch_total_tokens {
if args.max_batch_prefill_tokens > *max_batch_total_tokens {
return Err(LauncherError::ArgumentValidation(format!(
"`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}",
args.max_batch_prefill_tokens, max_batch_total_tokens
)));
}
if args.max_total_tokens as u32 > *max_batch_total_tokens {
return Err(LauncherError::ArgumentValidation(format!(
"`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}",
args.max_total_tokens, max_batch_total_tokens
)));
}
}
if args.ngrok {
if args.ngrok_authtoken.is_none() {
return Err(LauncherError::ArgumentValidation(
"`ngrok-authtoken` must be set when using ngrok tunneling".to_string(),
));
}
if args.ngrok_edge.is_none() {
return Err(LauncherError::ArgumentValidation(
"`ngrok-edge` must be set when using ngrok tunneling".to_string(),
));
}
}
// Signal handler
let running = Arc::new(AtomicBool::new(true));
let r = running.clone();
@ -1101,11 +1110,8 @@ fn main() -> Result<(), LauncherError> {
let mut exit_code = Ok(());
while running.load(Ordering::SeqCst) {
if let Ok(ShardStatus::Failed((rank, err))) = status_receiver.try_recv() {
if let Ok(ShardStatus::Failed(rank)) = status_receiver.try_recv() {
tracing::error!("Shard {rank} crashed");
if let Some(err) = err {
tracing::error!("{err}");
}
exit_code = Err(LauncherError::ShardFailed);
break;
};


@ -198,9 +198,10 @@ message DecodeResponse {
message WarmupRequest {
/// Batch to warmup on
Batch batch = 1;
/// Maximum number of tokens that the client will send
uint32 max_total_tokens = 2;
}
/// Empty response
message WarmupResponse {}
message WarmupResponse {
/// Maximum number of tokens supported by the model
optional uint32 max_supported_total_tokens = 1;
}


@ -103,8 +103,7 @@ impl Client {
&mut self,
max_input_length: u32,
max_prefill_tokens: u32,
max_total_tokens: u32,
) -> Result<()> {
) -> Result<Option<u32>> {
let mut n_tokens = 0;
let mut requests = Vec::new();
@ -143,13 +142,9 @@ impl Client {
max_tokens: 0,
};
let request = tonic::Request::new(WarmupRequest {
batch: Some(batch),
max_total_tokens,
})
.inject_context();
self.stub.warmup(request).await?.into_inner();
Ok(())
let request = tonic::Request::new(WarmupRequest { batch: Some(batch) }).inject_context();
let response = self.stub.warmup(request).await?.into_inner();
Ok(response.max_supported_total_tokens)
}
/// Generate one token for each request in the given batch


@ -95,14 +95,11 @@ impl ShardedClient {
&mut self,
max_input_length: u32,
max_prefill_tokens: u32,
max_total_tokens: u32,
) -> Result<()> {
) -> Result<Option<u32>> {
let futures: Vec<_> = self
.clients
.iter_mut()
.map(|client| {
Box::pin(client.warmup(max_input_length, max_prefill_tokens, max_total_tokens))
})
.map(|client| Box::pin(client.warmup(max_input_length, max_prefill_tokens)))
.collect();
// all shards return the same message
join_all(futures).await.pop().unwrap()


@ -53,7 +53,7 @@ impl Infer {
generation_health: Arc<AtomicBool>,
) -> Self {
// Infer shared state
let queue = Queue::new(requires_padding);
let queue = Queue::new(requires_padding, 16);
let shared = Arc::new(Shared {
batching_task: Notify::new(),
});


@ -37,8 +37,8 @@ struct Args {
waiting_served_ratio: f32,
#[clap(default_value = "4096", long, env)]
max_batch_prefill_tokens: u32,
#[clap(default_value = "16000", long, env)]
max_batch_total_tokens: u32,
#[clap(long, env)]
max_batch_total_tokens: Option<u32>,
#[clap(default_value = "20", long, env)]
max_waiting_tokens: usize,
#[clap(default_value = "0.0.0.0", long, env)]
@ -64,11 +64,7 @@ struct Args {
#[clap(long, env)]
ngrok_authtoken: Option<String>,
#[clap(long, env)]
ngrok_domain: Option<String>,
#[clap(long, env)]
ngrok_username: Option<String>,
#[clap(long, env)]
ngrok_password: Option<String>,
ngrok_edge: Option<String>,
}
fn main() -> Result<(), RouterError> {
@ -96,9 +92,7 @@ fn main() -> Result<(), RouterError> {
cors_allow_origin,
ngrok,
ngrok_authtoken,
ngrok_domain,
ngrok_username,
ngrok_password,
ngrok_edge,
} = args;
// Validate args
@ -110,18 +104,22 @@ fn main() -> Result<(), RouterError> {
if max_input_length as u32 > max_batch_prefill_tokens {
return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be >= `max_input_length`. Given: {max_batch_prefill_tokens} and {max_input_length}")));
}
if max_batch_prefill_tokens > max_batch_total_tokens {
return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {max_batch_prefill_tokens} and {max_batch_total_tokens}")));
}
if max_total_tokens as u32 > max_batch_total_tokens {
return Err(RouterError::ArgumentValidation(format!("`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {max_total_tokens} and {max_batch_total_tokens}")));
}
if validation_workers == 0 {
return Err(RouterError::ArgumentValidation(
"`validation_workers` must be > 0".to_string(),
));
}
if let Some(ref max_batch_total_tokens) = max_batch_total_tokens {
if max_batch_prefill_tokens > *max_batch_total_tokens {
return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {max_batch_prefill_tokens} and {max_batch_total_tokens}")));
}
if max_total_tokens as u32 > *max_batch_total_tokens {
return Err(RouterError::ArgumentValidation(format!("`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {max_total_tokens} and {max_batch_total_tokens}")));
}
}
// CORS allowed origins
// map to go inside the option and then map to parse from String to HeaderValue
// Finally, convert to AllowOrigin
@ -210,14 +208,35 @@ fn main() -> Result<(), RouterError> {
// Warmup model
tracing::info!("Warming up model");
sharded_client
.warmup(
max_input_length as u32,
max_batch_prefill_tokens,
max_batch_total_tokens,
)
let max_supported_batch_total_tokens = match sharded_client
.warmup(max_input_length as u32, max_batch_prefill_tokens)
.await
.map_err(RouterError::Warmup)?;
.map_err(RouterError::Warmup)?
{
// Older models do not support automatic max-batch-total-tokens
None => {
let max_batch_total_tokens = max_batch_total_tokens.unwrap_or(
16000.max((max_total_tokens as u32).max(max_batch_prefill_tokens)),
);
tracing::warn!("Model does not support automatic max batch total tokens");
max_batch_total_tokens
}
// Flash attention models return their max supported total tokens
Some(max_supported_batch_total_tokens) => {
// Warn if user added his own max-batch-total-tokens as we will ignore it
if max_batch_total_tokens.is_some() {
tracing::warn!(
"`--max-batch-total-tokens` is deprecated for Flash \
Attention models."
);
tracing::warn!(
"Inferred max batch total tokens: {max_supported_batch_total_tokens}"
);
}
max_supported_batch_total_tokens
}
};
tracing::info!("Setting max batch total tokens to {max_supported_batch_total_tokens}");
tracing::info!("Connected");
let addr = match hostname.parse() {
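The router hunk above chooses the effective batch budget: a value reported by warmup wins; otherwise the user-supplied `--max-batch-total-tokens` is used, with a fallback default of 16000 clamped to at least `max_total_tokens` and `max_batch_prefill_tokens`. A minimal Python sketch of that selection logic (function and argument names here are illustrative, not part of the codebase):

from typing import Optional

def resolve_max_batch_total_tokens(
    supported: Optional[int],           # value reported by warmup; None for older models
    user_value: Optional[int],          # --max-batch-total-tokens, if the user passed it
    max_total_tokens: int,
    max_batch_prefill_tokens: int,
) -> int:
    if supported is not None:
        # Flash Attention models infer their own budget; any user flag is ignored with a warning.
        return supported
    # Older models: honour the flag, otherwise fall back to a conservative default.
    if user_value is not None:
        return user_value
    return max(16000, max_total_tokens, max_batch_prefill_tokens)

assert resolve_max_batch_total_tokens(None, None, 2048, 4096) == 16000
assert resolve_max_batch_total_tokens(120000, 32000, 2048, 4096) == 120000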
@ -240,7 +259,7 @@ fn main() -> Result<(), RouterError> {
max_total_tokens,
waiting_served_ratio,
max_batch_prefill_tokens,
max_batch_total_tokens,
max_supported_batch_total_tokens,
max_waiting_tokens,
sharded_client,
tokenizer,
@ -249,9 +268,7 @@ fn main() -> Result<(), RouterError> {
cors_allow_origin,
ngrok,
ngrok_authtoken,
ngrok_domain,
ngrok_username,
ngrok_password,
ngrok_edge,
)
.await?;
Ok(())

View File

@ -33,12 +33,12 @@ pub(crate) struct Queue {
}
impl Queue {
pub(crate) fn new(requires_padding: bool) -> Self {
pub(crate) fn new(requires_padding: bool, block_size: u32) -> Self {
// Create channel
let (queue_sender, queue_receiver) = flume::unbounded();
// Launch background queue task
tokio::spawn(queue_task(requires_padding, queue_receiver));
tokio::spawn(queue_task(requires_padding, block_size, queue_receiver));
Self { queue_sender }
}
@ -81,8 +81,12 @@ impl Queue {
}
// Background task responsible of the queue state
async fn queue_task(requires_padding: bool, receiver: flume::Receiver<QueueCommand>) {
let mut state = State::new(requires_padding);
async fn queue_task(
requires_padding: bool,
block_size: u32,
receiver: flume::Receiver<QueueCommand>,
) {
let mut state = State::new(requires_padding, block_size);
while let Ok(cmd) = receiver.recv_async().await {
match cmd {
@ -119,15 +123,19 @@ struct State {
/// Whether the model is using padding
requires_padding: bool,
/// Paged Attention block size
block_size: u32,
}
impl State {
fn new(requires_padding: bool) -> Self {
fn new(requires_padding: bool, block_size: u32) -> Self {
Self {
entries: VecDeque::with_capacity(128),
next_id: 0,
next_batch_id: 0,
requires_padding,
block_size,
}
}
@ -187,10 +195,21 @@ impl State {
max_input_length = max_input_length.max(entry.request.input_length);
prefill_tokens = (batch_requests.len() + 1) as u32 * max_input_length
} else {
prefill_tokens += entry.request.input_length;
// pad to block size
prefill_tokens += ((entry.request.input_length + self.block_size - 1)
/ self.block_size)
* self.block_size;
}
if self.requires_padding {
decode_tokens += entry.request.stopping_parameters.max_new_tokens;
} else {
// pad to block size
decode_tokens +=
((entry.request.stopping_parameters.max_new_tokens + self.block_size - 1)
/ self.block_size)
* self.block_size;
}
if prefill_tokens > prefill_token_budget
|| (prefill_tokens + decode_tokens) > token_budget
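The queue change above stops summing raw token counts and instead rounds each request up to whole Paged Attention blocks, so the budget check mirrors how the KV cache is actually allocated (the block size of 16 passed in `Queue::new(requires_padding, 16)`). A small worked sketch of the rounding, in Python, with illustrative numbers:

def pad_to_block(tokens: int, block_size: int) -> int:
    # Ceil-divide, then multiply back: a partial block still costs a full block.
    return ((tokens + block_size - 1) // block_size) * block_size

assert pad_to_block(9, 16) == 16    # 9 prefill tokens occupy one 16-slot block
assert pad_to_block(16, 16) == 16
assert pad_to_block(17, 16) == 32   # one extra token spills into a second block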
@ -321,7 +340,7 @@ mod tests {
#[test]
fn test_append() {
let mut state = State::new(false);
let mut state = State::new(false, 1);
let (entry, _guard) = default_entry();
assert_eq!(state.next_id, 0);
@ -337,7 +356,7 @@ mod tests {
#[test]
fn test_next_batch_empty() {
let mut state = State::new(false);
let mut state = State::new(false, 1);
assert!(state.next_batch(None, 1, 1).is_none());
assert!(state.next_batch(Some(1), 1, 1).is_none());
@ -345,7 +364,7 @@ mod tests {
#[test]
fn test_next_batch_min_size() {
let mut state = State::new(false);
let mut state = State::new(false, 1);
let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry();
state.append(entry1);
@ -377,7 +396,7 @@ mod tests {
#[test]
fn test_next_batch_token_budget() {
let mut state = State::new(false);
let mut state = State::new(false, 1);
let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry();
state.append(entry1);
@ -410,14 +429,14 @@ mod tests {
#[tokio::test]
async fn test_queue_append() {
let queue = Queue::new(false);
let queue = Queue::new(false, 1);
let (entry, _guard) = default_entry();
queue.append(entry);
}
#[tokio::test]
async fn test_queue_next_batch_empty() {
let queue = Queue::new(false);
let queue = Queue::new(false, 1);
assert!(queue.next_batch(None, 1, 1).await.is_none());
assert!(queue.next_batch(Some(1), 1, 1).await.is_none());
@ -425,7 +444,7 @@ mod tests {
#[tokio::test]
async fn test_queue_next_batch_min_size() {
let queue = Queue::new(false);
let queue = Queue::new(false, 1);
let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry();
queue.append(entry1);
@ -458,7 +477,7 @@ mod tests {
#[tokio::test]
async fn test_queue_next_batch_token_budget() {
let queue = Queue::new(false);
let queue = Queue::new(false, 1);
let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry();
queue.append(entry1);
@ -483,7 +502,7 @@ mod tests {
#[tokio::test]
async fn test_queue_next_batch_dropped_receiver() {
let queue = Queue::new(false);
let queue = Queue::new(false, 1);
let (entry, _) = default_entry();
queue.append(entry);


@ -524,9 +524,7 @@ pub async fn run(
allow_origin: Option<AllowOrigin>,
ngrok: bool,
ngrok_authtoken: Option<String>,
ngrok_domain: Option<String>,
ngrok_username: Option<String>,
ngrok_password: Option<String>,
ngrok_edge: Option<String>,
) -> Result<(), axum::BoxError> {
// OpenAPI documentation
#[derive(OpenApi)]
@ -696,32 +694,25 @@ pub async fn run(
#[cfg(feature = "ngrok")]
{
use ngrok::config::TunnelBuilder;
use ngrok::tunnel::UrlTunnel;
let _ = addr;
let authtoken =
ngrok_authtoken.expect("`ngrok-authtoken` must be set when using ngrok tunneling");
let mut tunnel = ngrok::Session::builder()
let edge = ngrok_edge.expect("`ngrok-edge` must be set when using ngrok tunneling");
let tunnel = ngrok::Session::builder()
.authtoken(authtoken)
.connect()
.await
.unwrap()
.http_endpoint();
if let Some(domain) = ngrok_domain {
tunnel = tunnel.domain(domain);
}
if let (Some(username), Some(password)) = (ngrok_username, ngrok_password) {
tunnel = tunnel.basic_auth(username, password);
}
.labeled_tunnel()
.label("edge", edge);
let listener = tunnel.listen().await.unwrap();
// Run server
tracing::info!("Ingress URL: {:?}", listener.url());
axum::Server::builder(listener)
.serve(app.into_make_service())
//Wait until all requests are finished to shut down


@ -1,4 +1,5 @@
include Makefile-flash-att
include Makefile-flash-att-v2
include Makefile-vllm
unit-tests:


@ -0,0 +1,13 @@
flash_att_v2_commit := 4f285b354796fb17df8636485b9a04df3ebbb7dc
flash-attention-v2:
# Clone flash attention
pip install packaging
git clone https://github.com/HazyResearch/flash-attention.git flash-attention-v2
build-flash-attention-v2: flash-attention-v2
cd flash-attention-v2 && git fetch && git checkout $(flash_att_v2_commit)
cd flash-attention-v2 && python setup.py build
install-flash-attention-v2: build-flash-attention-v2
cd flash-attention-v2 && python setup.py install


@ -1,6 +1,6 @@
[tool.poetry]
name = "text-generation-server"
version = "0.9.2"
version = "0.9.3"
description = "Text Generation Inference Python gRPC Server"
authors = ["Olivier Dehaene <olivier@huggingface.co>"]


@ -194,6 +194,8 @@ def quantize(
percdamp: float = 0.01,
act_order: bool = False,
):
if revision is None:
revision = "main"
download_weights(
model_id=model_id,
revision=revision,
@ -207,6 +209,7 @@ def quantize(
bits=4,
groupsize=128,
output_dir=output_dir,
revision=revision,
trust_remote_code=trust_remote_code,
upload_to_model_id=upload_to_model_id,
percdamp=percdamp,


@ -42,35 +42,10 @@ __all__ = [
"get_model",
]
FLASH_ATT_ERROR_MESSAGE = (
"{} requires CUDA and Flash Attention kernels to be installed.\n"
"Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
"or install flash attention with `cd server && make install install-flash-attention`"
)
FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models."
FLASH_ATTENTION = True
try:
if not os.getenv("USE_FLASH_ATTENTION", "").lower() == "false":
if not torch.cuda.is_available():
FLASH_ATT_ERROR_MESSAGE = (
"{} requires CUDA. No compatible CUDA devices found."
)
raise ImportError("CUDA is not available")
major, minor = torch.cuda.get_device_capability()
is_sm75 = major == 7 and minor == 5
is_sm8x = major == 8 and minor >= 0
is_sm90 = major == 9 and minor == 0
supported = is_sm75 or is_sm8x or is_sm90
if not supported:
FLASH_ATT_ERROR_MESSAGE = (
"{} requires a CUDA device with capability 7.5, > 8.0 or 9.0. "
"No compatible CUDA device found."
)
raise ImportError(
f"GPU with CUDA capability {major} {minor} is not supported"
)
from text_generation_server.models.flash_rw import FlashRWSharded
from text_generation_server.models.flash_neox import FlashNeoXSharded
from text_generation_server.models.flash_llama import (
@ -80,13 +55,8 @@ try:
FlashSantacoderSharded,
)
FLASH_ATTENTION = True
else:
FLASH_ATTENTION = False
except ImportError:
logger.opt(exception=True).warning(
"Could not import Flash Attention enabled models"
)
except ImportError as e:
logger.warning(f"Could not import Flash Attention enabled models: {e}")
FLASH_ATTENTION = False
if FLASH_ATTENTION:


@ -23,25 +23,77 @@ import torch.distributed
from torch import nn
from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig
from typing import Optional, List, Tuple
# Flash attention imports
import flash_attn_cuda
import dropout_layer_norm
# vllm imports
import vllm_cache_ops
import vllm_attention_ops
from text_generation_server.utils.flash_attn import attention
from text_generation_server.utils.layers import (
TensorParallelRowLinear,
TensorParallelColumnLinear,
TensorParallelEmbedding,
PositionRotaryEmbedding,
TensorParallelHead,
get_linear,
)
class LlamaConfig(PretrainedConfig):
def __init__(
self,
vocab_size=32000,
hidden_size=4096,
intermediate_size=11008,
num_hidden_layers=32,
num_attention_heads=32,
num_key_value_heads=None,
hidden_act="silu",
max_position_embeddings=2048,
initializer_range=0.02,
rms_norm_eps=1e-6,
use_cache=True,
pad_token_id=0,
bos_token_id=1,
eos_token_id=2,
pretraining_tp=1,
tie_word_embeddings=False,
rope_scaling=None,
**kwargs,
):
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
# for backward compatibility
if num_key_value_heads is None:
num_key_value_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.pretraining_tp = pretraining_tp
self.use_cache = use_cache
self.rope_scaling = rope_scaling
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
class LlamaRMSNorm(nn.Module):
def __init__(self, prefix, weights, eps=1e-6):
"""
@ -59,7 +111,8 @@ class LlamaRMSNorm(nn.Module):
hidden_states += residual
residual = hidden_states
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(
variance + self.variance_epsilon
)
@ -94,6 +147,27 @@ class LlamaRMSNorm(nn.Module):
return normed_hidden_states, res
def _load_gqa(config, prefix: str, weights):
w = [
weights.get_sharded(f"{prefix}.q_proj.weight", dim=0),
weights.get_sharded(f"{prefix}.k_proj.weight", dim=0),
weights.get_sharded(f"{prefix}.v_proj.weight", dim=0),
]
weight = torch.cat(w, dim=0)
weight = weight.to(dtype=weights.dtype).to(device=weights.device)
bias = None
assert config.hidden_size % config.num_attention_heads == 0
head_size = config.hidden_size // config.num_attention_heads
assert config.num_attention_heads % weights.process_group.size() == 0
num_heads = config.num_attention_heads // weights.process_group.size()
num_key_value_heads = config.num_key_value_heads // weights.process_group.size()
assert list(weight.shape) == [
(num_heads + 2 * num_key_value_heads) * head_size,
config.hidden_size,
], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"
return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize))
class FlashLlamaAttention(torch.nn.Module):
def __init__(
self,
@ -118,6 +192,12 @@ class FlashLlamaAttention(torch.nn.Module):
f"and `num_shards`: {weights.process_group.size()}"
)
self.num_heads = self.num_heads // weights.process_group.size()
self.num_key_value_heads = (
config.num_key_value_heads // weights.process_group.size()
)
if config.num_attention_heads != config.num_key_value_heads:
self.query_key_value = _load_gqa(config, prefix, weights)
else:
self.query_key_value = TensorParallelColumnLinear.load_multi(
config,
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
@ -131,9 +211,10 @@ class FlashLlamaAttention(torch.nn.Module):
weights=weights,
bias=False,
)
self.num_groups = self.num_heads // self.num_key_value_heads
self.kv_head_mapping = torch.arange(
0, self.num_heads, dtype=torch.int32, device=weights.device
)
0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
).repeat_interleave(self.num_groups)
def forward(
self,
@ -148,38 +229,37 @@ class FlashLlamaAttention(torch.nn.Module):
max_s,
):
qkv = self.query_key_value(hidden_states)
qkv = qkv.view(-1, 3, self.num_heads, self.head_size)
query, kv = qkv.split(
[
self.head_size * self.num_heads,
2 * self.head_size * self.num_key_value_heads,
],
dim=1,
)
query = query.view(-1, self.num_heads, self.head_size)
kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)
# Inplace rotary
self.rotary_emb(qkv[:, 0], cos, sin)
self.rotary_emb(qkv[:, 1], cos, sin)
self.rotary_emb(query, cos, sin)
self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin)
vllm_cache_ops.reshape_and_cache(
qkv[:, 1], qkv[:, 2], kv_cache[0], kv_cache[1], slots
kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots
)
# output tensor
attn_output = torch.empty_like(qkv[:, 0])
attn_output = torch.empty_like(query)
# Prefill
if cu_seqlen_prefill is not None:
# flash attention
flash_attn_cuda.fwd(
qkv[:, 0],
qkv[:, 1],
qkv[:, 2],
attention(
query,
torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1),
attn_output,
cu_seqlen_prefill,
cu_seqlen_prefill,
max_s,
max_s,
0.0,
self.softmax_scale,
False,
True,
False,
0,
None,
)
# Decode
else:
@ -187,7 +267,7 @@ class FlashLlamaAttention(torch.nn.Module):
block_size = kv_cache[1].shape[3]
vllm_attention_ops.single_query_cached_kv_attention(
attn_output,
qkv[:, 0],
query,
kv_cache[0],
kv_cache[1],
self.kv_head_mapping,
@ -324,6 +404,7 @@ class FlashLlamaModel(torch.nn.Module):
self.head_size = self.layers[0].self_attn.head_size
self.num_heads = self.layers[0].self_attn.num_heads
self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads
def forward(
self,

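The Llama modeling changes above add grouped-query attention: `num_key_value_heads` can be smaller than `num_attention_heads`, and `kv_head_mapping` assigns each query head to the key/value head it shares via `repeat_interleave`. A tiny PyTorch illustration of that mapping, with head counts chosen only for the example:

import torch

num_heads = 8              # query heads on this shard
num_key_value_heads = 2    # shared KV heads on this shard
num_groups = num_heads // num_key_value_heads

# Query heads 0-3 read KV head 0, query heads 4-7 read KV head 1.
kv_head_mapping = torch.arange(
    0, num_key_value_heads, dtype=torch.int32
).repeat_interleave(num_groups)
print(kv_head_mapping)     # tensor([0, 0, 0, 0, 1, 1, 1, 1], dtype=torch.int32)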

@ -27,13 +27,11 @@ from transformers.modeling_utils import PreTrainedModel
from transformers.models.gpt_neox import GPTNeoXConfig
from typing import Optional, List, Tuple
# Flash attention imports
import flash_attn_cuda
# vllm imports
import vllm_cache_ops
import vllm_attention_ops
from text_generation_server.utils.flash_attn import attention
from text_generation_server.utils.layers import (
TensorParallelRowLinear,
TensorParallelColumnLinear,
@ -153,22 +151,14 @@ class FlashNeoxAttention(torch.nn.Module):
# Prefill
if cu_seqlen_prefill is not None:
# flash attention
flash_attn_cuda.fwd(
attention(
qkv[:, 0],
qkv[:, 1],
qkv[:, 2],
attn_output,
cu_seqlen_prefill,
cu_seqlen_prefill,
max_s,
max_s,
0.0,
self.softmax_scale,
False,
True,
False,
0,
None,
)
# Decode
else:


@ -6,13 +6,11 @@ from transformers.modeling_utils import PreTrainedModel
from transformers.configuration_utils import PretrainedConfig
from typing import Optional, List, Tuple
# Flash attention imports
import flash_attn_cuda
# vllm imports
import vllm_cache_ops
import vllm_attention_ops
from text_generation_server.utils.flash_attn import attention
from text_generation_server.utils.layers import (
TensorParallelRowLinear,
TensorParallelColumnLinear,
@ -182,27 +180,15 @@ class FlashRWAttention(torch.nn.Module):
# Prefill
if cu_seqlen_prefill is not None:
if self.num_heads_kv == 1:
# Expand to query shape
kv = kv.expand(-1, 2, self.num_heads, self.head_size)
# flash attention
flash_attn_cuda.fwd(
attention(
query,
torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1),
attn_output,
cu_seqlen_prefill,
cu_seqlen_prefill,
max_s,
max_s,
0.0,
self.softmax_scale,
False,
True,
False,
0,
None,
)
# Decode
else:
@ -314,30 +300,15 @@ class FlashRWLargeAttention(torch.nn.Module):
# Prefill
if cu_seqlen_prefill is not None:
# Expand to query shape
kv = (
kv.unsqueeze(2)
.expand(-1, self.num_groups, self.num_heads, 2, self.head_size)
.reshape(-1, self.num_groups * self.num_heads, 2, self.head_size)
)
# flash attention
flash_attn_cuda.fwd(
attention(
query,
torch.select(kv, dim=2, index=0),
torch.select(kv, dim=2, index=1),
attn_output,
cu_seqlen_prefill,
cu_seqlen_prefill,
max_s,
max_s,
0.0,
self.softmax_scale,
False,
True,
False,
0,
None,
)
# Decode
else:


@ -5,13 +5,11 @@ from torch import nn
from transformers.activations import ACT2FN
from typing import Optional, List, Tuple
# Flash attention imports
import flash_attn_cuda
# vllm imports
import vllm_cache_ops
import vllm_attention_ops
from text_generation_server.utils.flash_attn import attention
from text_generation_server.utils.layers import (
TensorParallelRowLinear,
TensorParallelColumnLinear,
@ -271,26 +269,15 @@ class FlashMQAttention(torch.nn.Module):
# Prefill
if cu_seqlen_prefill is not None:
# Expand from 1 to num_heads
key_value = key_value.expand(-1, 2, self.num_heads, self.head_size)
# flash attention
flash_attn_cuda.fwd(
attention(
query,
torch.select(key_value, dim=1, index=0),
torch.select(key_value, dim=1, index=1),
attn_output,
cu_seqlen_prefill,
cu_seqlen_prefill,
max_s,
max_s,
0.0,
self.softmax_scale,
False,
True,
False,
0,
None,
)
# Decode
else:


@ -710,14 +710,14 @@ class FlashCausalLM(Model):
def batch_type(self) -> Type[FlashCausalLMBatch]:
return FlashCausalLMBatch
def warmup(self, batch: FlashCausalLMBatch, max_total_tokens: int):
def warmup(self, batch: FlashCausalLMBatch):
global CACHE_MANAGER
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats(self.device)
try:
CACHE_MANAGER = CacheManager(
# Adds some wiggle room
math.ceil(max_total_tokens / BLOCK_SIZE) + 10,
batch.blocks,
self.num_layers,
self.num_kv_heads,
self.head_size,
@ -727,11 +727,43 @@ class FlashCausalLM(Model):
_, batch = self.generate_token(batch)
except Exception as e:
raise RuntimeError(
f"Not enough memory to handle {max_total_tokens} total tokens with {len(batch.input_ids)} "
f"prefill tokens. "
f"You need to decrease `--max-batch-total-tokens` or `--max-batch-prefill-tokens`"
f"Not enough memory to handle {len(batch.input_ids)} prefill tokens. "
f"You need to decrease `--max-batch-prefill-tokens`"
) from e
# Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm)
# Calculate the number of blocks that can be allocated with the
# profiled peak memory.
torch.cuda.synchronize(self.device)
peak_memory = torch.cuda.max_memory_reserved(self.device)
dtype_size = torch.tensor([], dtype=self.dtype).element_size()
cache_block_size = BLOCK_SIZE * self.num_kv_heads * self.head_size
total_cache_size = self.num_layers * cache_block_size * 2 * dtype_size
total_gpu_memory = torch.cuda.get_device_properties(self.device).total_memory
# 0.98 to add some wiggle room
num_blocks = (
int((total_gpu_memory * 0.98 - peak_memory) // total_cache_size)
# Add batch.blocks as we allocated it above, so it is included in the peak memory.
+ batch.blocks
)
del CACHE_MANAGER
del batch
torch.cuda.empty_cache()
CACHE_MANAGER = CacheManager(
num_blocks,
self.num_layers,
self.num_kv_heads,
self.head_size,
self.dtype,
self.device,
)
return int(num_blocks * BLOCK_SIZE)
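The warmup rewrite above sizes the KV cache from measured memory instead of a user-supplied total-token cap: after a prefill pass it reads the peak reserved memory and converts the remaining headroom into Paged Attention blocks. A rough Python rendering of that arithmetic; the 0.98 factor and the formula mirror the code above, while the hardware numbers are made up for illustration:

BLOCK_SIZE = 16
num_layers, num_kv_heads, head_size = 32, 32, 128
dtype_size = 2                                  # bytes per float16 element

# Bytes needed to hold one block of keys and values across all layers.
cache_block_size = BLOCK_SIZE * num_kv_heads * head_size
total_cache_size = num_layers * cache_block_size * 2 * dtype_size

total_gpu_memory = 80 * 1024**3                 # pretend 80 GiB card
peak_memory = 30 * 1024**3                      # pretend peak after the warmup prefill
warmup_blocks = 64                              # blocks already allocated for the warmup batch

num_blocks = int((total_gpu_memory * 0.98 - peak_memory) // total_cache_size) + warmup_blocks
max_supported_total_tokens = num_blocks * BLOCK_SIZE
print(num_blocks, max_supported_total_tokens)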
def decode(self, generated_ids: Union[torch.Tensor, List[int]]) -> str:
return self.tokenizer.decode(
@ -991,7 +1023,6 @@ class FlashCausalLM(Model):
if stopped:
del batch
torch.cuda.empty_cache()
# No need to return a batch if we know that all requests stopped
return generations, None


@ -9,6 +9,7 @@ from typing import Optional
from text_generation_server.models import FlashCausalLM
from text_generation_server.models.custom_modeling.flash_llama_modeling import (
FlashLlamaForCausalLM,
LlamaConfig,
)
from text_generation_server.utils import (
initialize_torch_distributed,
@ -52,7 +53,7 @@ class FlashLlama(FlashCausalLM):
trust_remote_code=trust_remote_code,
)
config = AutoConfig.from_pretrained(
config = LlamaConfig.from_pretrained(
model_id, revision=revision, trust_remote_code=trust_remote_code
)
@ -69,7 +70,7 @@ class FlashLlama(FlashCausalLM):
model=model,
tokenizer=tokenizer,
num_layers=len(model.model.layers),
num_kv_heads=model.model.num_heads,
num_kv_heads=model.model.num_key_value_heads,
head_size=model.model.head_size,
dtype=dtype,
device=device,


@ -58,8 +58,9 @@ class Model(ABC):
def generate_token(self, batch: B) -> Tuple[List[GeneratedText], Optional[B]]:
raise NotImplementedError
def warmup(self, batch: B, max_total_tokens: int):
def warmup(self, batch: B) -> Optional[int]:
self.generate_token(batch)
return None
def decode_token(
self,


@ -51,21 +51,17 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
filtered_batch = batch.filter(request.request_ids)
self.cache.set(filtered_batch)
if torch.cuda.is_available():
torch.cuda.empty_cache()
return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb())
async def Warmup(self, request, context):
batch = self.model.batch_type.from_pb(
request.batch, self.model.tokenizer, self.model.dtype, self.model.device
)
self.model.warmup(batch, request.max_total_tokens)
max_supported_total_tokens = self.model.warmup(batch)
if torch.cuda.is_available():
torch.cuda.empty_cache()
return generate_pb2.WarmupResponse()
return generate_pb2.WarmupResponse(
max_supported_total_tokens=max_supported_total_tokens
)
async def Prefill(self, request, context):
batch = self.model.batch_type.from_pb(
@ -96,8 +92,6 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
if len(batches) > 1:
batch = self.model.batch_type.concatenate(batches)
if torch.cuda.is_available():
torch.cuda.empty_cache()
else:
batch = batches[0]


@ -0,0 +1,124 @@
import os
import torch
from loguru import logger
if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false":
raise ImportError("`USE_FLASH_ATTENTION` is false.")
if not torch.cuda.is_available():
raise ImportError("CUDA is not available")
major, minor = torch.cuda.get_device_capability()
is_sm75 = major == 7 and minor == 5
is_sm8x = major == 8 and minor >= 0
is_sm90 = major == 9 and minor == 0
HAS_FLASH_ATTN = False
HAS_FLASH_ATTN_V2 = False
try:
try:
import flash_attn_2_cuda
except ImportError:
raise ImportError(
"Flash Attention V2 is not installed.\n"
"Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
"or install flash attention v2 with `cd server && make install install-flash-attention-v2`"
)
if not (is_sm8x or is_sm90):
raise ImportError(
f"GPU with CUDA capability {major} {minor} is not supported for "
"Flash Attention V2"
)
HAS_FLASH_ATTN_V2 = True
except ImportError as e:
try:
import flash_attn_cuda
except ImportError:
raise ImportError(
"Flash Attention is not installed.\n"
"Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
"or install flash attention with `cd server && make install install-flash-attention`"
) from e
if not (is_sm75 or is_sm8x or is_sm90):
raise ImportError(
f"GPU with CUDA capability {major} {minor} is not supported"
) from e
logger.warning(f"Unable to use Flash Attention V2: {e}")
HAS_FLASH_ATTN = True
def attention(
q,
k,
v,
out,
cu_seqlens,
max_s,
softmax_scale,
):
if HAS_FLASH_ATTN_V2:
return flash_attn_2_cuda.varlen_fwd(
q,
k,
v,
out,
cu_seqlens,
cu_seqlens,
max_s,
max_s,
0.0,
softmax_scale,
False,
True,
False,
None,
)
if HAS_FLASH_ATTN:
# Flash attention v1 requires q, k and v to have the same number of heads
if k.shape[1] != q.shape[1]:
# MQA expand
if k.shape[1] == 1:
k = k.expand(-1, q.shape[1], -1)
# Grouped attention reshape
else:
original_shape = k.shape
k = (
k.unsqueeze(2)
.expand(-1, -1, q.shape[1] // k.shape[1], -1)
.reshape(original_shape[0], -1, original_shape[2])
)
if v.shape[1] != q.shape[1]:
# MQA expand
if v.shape[1] == 1:
v = v.expand(-1, q.shape[1], -1)
# Grouped attention reshape
else:
original_shape = v.shape
v = (
v.unsqueeze(2)
.expand(-1, -1, q.shape[1] // v.shape[1], -1)
.reshape(original_shape[0], -1, original_shape[2])
)
return flash_attn_cuda.fwd(
q,
k,
v,
out,
cu_seqlens,
cu_seqlens,
max_s,
max_s,
0.0,
softmax_scale,
False,
True,
False,
0,
None,
)
raise NotImplementedError("flash attention is not installed")
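The new helper above prefers the Flash Attention V2 kernels and falls back to V1; since the V1 kernels require `q`, `k` and `v` to carry the same number of heads, multi-query and grouped-query key/value tensors are expanded to the query head count first. A small PyTorch sketch of that expansion step, with illustrative shapes:

import torch

def expand_kv_heads(t: torch.Tensor, num_query_heads: int) -> torch.Tensor:
    # t has shape [tokens, kv_heads, head_size]; return [tokens, num_query_heads, head_size].
    tokens, kv_heads, head_size = t.shape
    if kv_heads == num_query_heads:
        return t
    if kv_heads == 1:
        # Multi-query attention: broadcast the single KV head to every query head.
        return t.expand(tokens, num_query_heads, head_size)
    # Grouped-query attention: repeat each KV head for its group of query heads.
    return (
        t.unsqueeze(2)
        .expand(tokens, kv_heads, num_query_heads // kv_heads, head_size)
        .reshape(tokens, num_query_heads, head_size)
    )

k = torch.randn(10, 2, 64)
print(expand_kv_heads(k, 8).shape)   # torch.Size([10, 8, 64])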


@ -13,6 +13,9 @@ import transformers
from huggingface_hub import HfApi
import numpy as np
import torch
from accelerate import init_empty_weights
from text_generation_server.utils import initialize_torch_distributed, Weights
from text_generation_server.utils.hub import weight_files
from text_generation_server.utils.gptq.quant_linear import QuantLinear
from loguru import logger
from typing import Optional
@ -38,7 +41,6 @@ class Quantizer(nn.Module):
maxshrink=0.8,
trits=False,
):
self.maxq = torch.tensor(2**bits - 1)
self.perchannel = perchannel
self.sym = sym
@ -600,6 +602,8 @@ def sequential(
nsamples,
bits,
groupsize,
*,
hooks,
percdamp=0.01,
sym: bool = False,
act_order: bool = False,
@ -637,7 +641,7 @@ def sequential(
layers[0] = Catcher(layers[0])
for batch in dataloader:
try:
model(batch[0])
model(batch[0].cuda())
except ValueError:
pass
layers[0] = layers[0].module
@ -646,6 +650,8 @@ def sequential(
# model.model.embed_tokens = model.model.embed_tokens.cpu()
# model.model.norm = model.model.norm.cpu()
torch.cuda.empty_cache()
for hook in hooks:
hook.remove()
outs = torch.zeros_like(inps)
@ -662,10 +668,8 @@ def sequential(
print("| name | weight_error | fp_inp_SNR | q_inp_SNR | time |")
print("+==================+==============+============+===========+=======+")
from accelerate.hooks import remove_hook_from_submodules
layer = layers[i].to(dev)
remove_hook_from_submodules(layer)
layer = layers[i]
layer.load()
full = find_layers(layer)
sequential = [list(full.keys())]
@ -677,6 +681,7 @@ def sequential(
gptq[name].quantizer.configure(
bits, perchannel=True, sym=sym, mse=False
)
pass
def add_batch(name):
def tmp(_, inp, out):
@ -688,7 +693,6 @@ def sequential(
for name in subset:
handles.append(subset[name].register_forward_hook(add_batch(name)))
for j in range(nsamples):
outs[j] = layer(inps[j].unsqueeze(0), **extra)[0]
for h in handles:
h.remove()
@ -714,7 +718,7 @@ def sequential(
for j in range(nsamples):
outs[j] = layer(inps[j].unsqueeze(0), **extra)[0]
layers[i] = layer.cpu()
layer.unload()
del layer
del gptq
torch.cuda.empty_cache()
@ -768,24 +772,136 @@ def pack(model, quantizers, bits, groupsize):
return model
def setdeepattr(module, full_name, tensor):
current = module
tokens = full_name.split(".")
for token in tokens[:-1]:
current = getattr(current, token)
setattr(current, tokens[-1], tensor)
def getdeepattr(module, full_name):
current = module
tokens = full_name.split(".")
for token in tokens:
current = getattr(current, token)
return current
def load_weights_pre_hook(module_name, weights, recursive=False):
def inner(module, args):
print(f"Pre hook {module_name}")
local_params = {}
for k, v in module.named_parameters():
if not recursive and k.count(".") != 1:
continue
local_params[k] = v
for k, v in module.named_buffers():
if not recursive and k.count(".") != 1:
continue
local_params[k] = v
for local_param in local_params:
current_tensor = getdeepattr(module, local_param)
if current_tensor.device == torch.device("meta"):
# print(f"Loading {local_param}")
if module_name:
tensor_name = f"{module_name}.{local_param}"
else:
tensor_name = local_param
tensor = weights.get_tensor(tensor_name)
setdeepattr(module, local_param, nn.Parameter(tensor))
else:
setdeepattr(
module,
local_param,
nn.Parameter(current_tensor.to(device=torch.device("cuda:0"))),
)
return inner
def load_weights_post_hook(module_name, weights, recursive=False):
def inner(module, args, output):
print(f"Post hook {module_name}")
local_params = {}
for k, v in module.named_parameters():
if not recursive and k.count(".") != 1:
continue
local_params[k] = v
for k, v in module.named_buffers():
if not recursive and k.count(".") != 1:
continue
local_params[k] = v
for local_param in local_params:
# print(f"Unloading {local_param}")
current_tensor = getdeepattr(module, local_param)
setdeepattr(
module,
local_param,
nn.Parameter(current_tensor.to(device=torch.device("cpu"))),
)
return output
return inner
def quantize(
model_id: str,
bits: int,
groupsize: int,
output_dir: str,
revision: str,
trust_remote_code: bool,
upload_to_model_id: Optional[str],
percdamp: float,
act_order: bool,
):
print("loading model")
model = AutoModelForCausalLM.from_pretrained(
config = AutoConfig.from_pretrained(
model_id,
torch_dtype=torch.float16,
device_map="balanced_low_0",
trust_remote_code=trust_remote_code,
)
with init_empty_weights():
model = AutoModelForCausalLM.from_config(config, torch_dtype=torch.float16)
model = model.eval()
print("LOADED model")
files = weight_files(model_id, revision, extension=".safetensors")
process_group, _, _ = initialize_torch_distributed()
weights = Weights(
files,
device=torch.device("cuda:0"),
dtype=torch.float16,
process_group=process_group,
aliases={"embed_tokens.weight": ["lm_head.weight"]},
)
hooks = []
for name, module in model.named_modules():
def load(module, name):
def _load():
load_weights_pre_hook(name, weights, recursive=True)(module, None)
return _load
def unload(module, name):
def _unload():
load_weights_post_hook(name, weights, recursive=True)(
module, None, None
)
return _unload
module.load = load(module, name)
module.unload = unload(module, name)
hooks.append(
module.register_forward_pre_hook(load_weights_pre_hook(name, weights))
)
hooks.append(
module.register_forward_hook(load_weights_post_hook(name, weights))
)
model.seqlen = 2048
dataset = "wikitext2"
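The quantization changes above build the model skeleton with `init_empty_weights` and attach forward pre/post hooks so each layer's weights are materialized on the GPU just before the layer runs and evicted right after, keeping peak memory low during GPTQ calibration. A stripped-down sketch of that hook pattern; the layer, devices and shapes are illustrative only:

import torch
from torch import nn

device = "cuda:0" if torch.cuda.is_available() else "cpu"

def load_pre_hook(module, args):
    module.to(device)        # bring this layer's weights in just before it runs

def unload_post_hook(module, args, output):
    module.to("cpu")         # evict the weights again right after the forward pass
    return output

layer = nn.Linear(16, 16)
layer.register_forward_pre_hook(load_pre_hook)
layer.register_forward_hook(unload_post_hook)

x = torch.randn(2, 16, device=device)  # calibration inputs already live on the target device
out = layer(x)
print(out.shape, next(layer.parameters()).device)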
@ -806,6 +922,7 @@ def quantize(
groupsize,
percdamp=percdamp,
act_order=act_order,
hooks=hooks,
)
print(time.time() - tick)
@ -858,7 +975,6 @@ def quantize(
logger.info("Saved tokenizer")
if upload_to_model_id:
api = HfApi()
api.upload_folder(